Compare commits
300 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a9cbca1604 | ||
|
|
3a30ce6c16 | ||
|
|
48ec5355f9 | ||
|
|
11859bc322 | ||
|
|
28c67a5be8 | ||
|
|
44fe93e9b0 | ||
|
|
09a1681b63 | ||
|
|
f5ba2190fb | ||
|
|
14a38b5069 | ||
|
|
f23e5b602a | ||
|
|
857696ed9c | ||
|
|
2084133058 | ||
|
|
f7f0c3070e | ||
|
|
46235aa514 | ||
|
|
2eb65d21ac | ||
|
|
37a0d62a82 | ||
|
|
21ac46e439 | ||
|
|
ba3e8ba20c | ||
|
|
2c48e798ca | ||
|
|
4e40f5b62b | ||
|
|
2a8892b785 | ||
|
|
ee3b33ff03 | ||
|
|
b2c3001f8e | ||
|
|
6cfe1e1ac2 | ||
|
|
52326870e4 | ||
|
|
217fde0918 | ||
|
|
065021d82a | ||
|
|
4bb643e685 | ||
|
|
b77c745b1a | ||
|
|
7d13501b94 | ||
|
|
ac74639b32 | ||
|
|
12fa56ae68 | ||
|
|
f11b863f4b | ||
|
|
f3e4b72957 | ||
|
|
8d52fb46ca | ||
|
|
dab8f45033 | ||
|
|
bff8b02543 | ||
|
|
2406200914 | ||
|
|
db06fcfc84 | ||
|
|
93b9f74e9f | ||
|
|
33ec844f76 | ||
|
|
0f727b393e | ||
|
|
7da2aad6ee | ||
|
|
6f09f50d02 | ||
|
|
5919832059 | ||
|
|
f7635c1afc | ||
|
|
c762168ed0 | ||
|
|
67a46e553f | ||
|
|
e406f37b54 | ||
|
|
62fe877124 | ||
|
|
a0e682ba79 | ||
|
|
49e8a87383 | ||
|
|
b2764b49ca | ||
|
|
06b810de8f | ||
|
|
6da51565f5 | ||
|
|
1f69965239 | ||
|
|
af2d61178d | ||
|
|
6a955ccf4f | ||
|
|
c0658711ca | ||
|
|
d602f06882 | ||
|
|
1cb9a38ac2 | ||
|
|
47a1f73d0f | ||
|
|
142dd63b47 | ||
|
|
b1bd8370c2 | ||
|
|
215660c8da | ||
|
|
0cafe67efe | ||
|
|
ea83b3222b | ||
|
|
725087a04f | ||
|
|
d627ab4855 | ||
|
|
7d867e8df4 | ||
|
|
3d34d44497 | ||
|
|
a6f800b741 | ||
|
|
a003d1fa1e | ||
|
|
c2e84d4558 | ||
|
|
68330eab2a | ||
|
|
7070f3969d | ||
|
|
e4727ab155 | ||
|
|
280e7d97ad | ||
|
|
31e3805fb8 | ||
|
|
ef248dbe15 | ||
|
|
6a61b4b638 | ||
|
|
4b1473502f | ||
|
|
bf211d818d | ||
|
|
27dd87c890 | ||
|
|
8659084ab0 | ||
|
|
e1c9dcea93 | ||
|
|
171339ab17 | ||
|
|
8542ba5c69 | ||
|
|
97b74d328b | ||
|
|
3198a7e5f4 | ||
|
|
a2d08ce961 | ||
|
|
bd8ea09479 | ||
|
|
6d0d46c7fb | ||
|
|
820540780a | ||
|
|
f74d600497 | ||
|
|
94fec9f50e | ||
|
|
e387a50475 | ||
|
|
5c4248a29c | ||
|
|
f22886e2b6 | ||
|
|
33af3cbf37 | ||
|
|
728dfb1be7 | ||
|
|
e49f7f1afe | ||
|
|
21a454fa6c | ||
|
|
22c6c27f78 | ||
|
|
aecbb43096 | ||
|
|
fa53fd2db2 | ||
|
|
1c150995ae | ||
|
|
6c5d8f089e | ||
|
|
dd623325e8 | ||
|
|
e8a375c8f2 | ||
|
|
386d85ae72 | ||
|
|
ebb3901b05 | ||
|
|
20130b486c | ||
|
|
73c48d0463 | ||
|
|
f7cecd20e3 | ||
|
|
2bc64a7636 | ||
|
|
9564ddbb48 | ||
|
|
28062c71b5 | ||
|
|
35d1921081 | ||
|
|
4fbdf18c70 | ||
|
|
5e07ab01f0 | ||
|
|
fac465a21e | ||
|
|
e145a2ce0c | ||
|
|
dc68c313ee | ||
|
|
95c0d9ab24 | ||
|
|
46a718f339 | ||
|
|
496ba46960 | ||
|
|
43ae0aca1d | ||
|
|
b8574c1b82 | ||
|
|
32f8b1082b | ||
|
|
6443fef31a | ||
|
|
14c3795a7d | ||
|
|
3d9e2de573 | ||
|
|
0ca36a0f8d | ||
|
|
3e5555502a | ||
|
|
fbf5b5e0a9 | ||
|
|
3305e66f8c | ||
|
|
e19a44c12b | ||
|
|
8b0e6b9d1b | ||
|
|
f3e638ac6a | ||
|
|
42e0b30476 | ||
|
|
a09a7b650d | ||
|
|
332d7bbd56 | ||
|
|
d3b6fece71 | ||
|
|
9d963b82de | ||
|
|
a402161631 | ||
|
|
b481ad58e6 | ||
|
|
f91c5f2638 | ||
|
|
7143c551ab | ||
|
|
50e93392dd | ||
|
|
9f83e93839 | ||
|
|
692b132dbf | ||
|
|
e70b3e8947 | ||
|
|
612d97db6f | ||
|
|
bb1b67c076 | ||
|
|
5a75c31caa | ||
|
|
8b9210286b | ||
|
|
b5acec34f7 | ||
|
|
86d835878c | ||
|
|
eae7b331d3 | ||
|
|
ed89e29bcc | ||
|
|
c2b1886aff | ||
|
|
218f36bca5 | ||
|
|
b91fc1f5b3 | ||
|
|
2a22bf9c15 | ||
|
|
62e2037125 | ||
|
|
e5b72c6a77 | ||
|
|
93be211f80 | ||
|
|
9ae3fb4ced | ||
|
|
f641075789 | ||
|
|
f7658db1b6 | ||
|
|
b869bc1a20 | ||
|
|
a72d756d77 | ||
|
|
d3fd8f89b8 | ||
|
|
180a05a446 | ||
|
|
eb9ac9ee1f | ||
|
|
a6662b73f5 | ||
|
|
cbc7db3478 | ||
|
|
4606340f0f | ||
|
|
d4b4ccd597 | ||
|
|
9c3f4e3a37 | ||
|
|
440e00d8f9 | ||
|
|
6310613699 | ||
|
|
f55907dbea | ||
|
|
5cac87d317 | ||
|
|
9c0622de13 | ||
|
|
37b93c8b71 | ||
|
|
d6be98cda6 | ||
|
|
4d128acc17 | ||
|
|
516df9ecce | ||
|
|
8eec1d50e1 | ||
|
|
cfb096d43a | ||
|
|
713fa28804 | ||
|
|
5549f35939 | ||
|
|
6eed1db36c | ||
|
|
948124f55e | ||
|
|
2b191ca776 | ||
|
|
be4d2822ea | ||
|
|
736ddd0319 | ||
|
|
dfa289aa72 | ||
|
|
c2644f939a | ||
|
|
f11c1ae562 | ||
|
|
3126164aa6 | ||
|
|
ed10486cad | ||
|
|
04fa430c6c | ||
|
|
fa1893b59c | ||
|
|
e993e717a5 | ||
|
|
c80e56423a | ||
|
|
ffa09a01d6 | ||
|
|
7d04f8567b | ||
|
|
baa709674f | ||
|
|
ca9a494d0c | ||
|
|
37eb8c05cc | ||
|
|
7c046edb7b | ||
|
|
22cea38b20 | ||
|
|
ef2ca0a827 | ||
|
|
7f0b908de2 | ||
|
|
5fc5e776ff | ||
|
|
93b281c016 | ||
|
|
9585699918 | ||
|
|
bceaba551d | ||
|
|
0bfeed3a7e | ||
|
|
70a780c3c0 | ||
|
|
d74ab5306c | ||
|
|
688e8601ab | ||
|
|
4933ab5956 | ||
|
|
6c7225a5d4 | ||
|
|
a22982f2fa | ||
|
|
c95479dddb | ||
|
|
fc48bd8da0 | ||
|
|
d5323bfa3f | ||
|
|
e9d4a2b507 | ||
|
|
37bcbe8046 | ||
|
|
fdfb644f0a | ||
|
|
cde9f3db57 | ||
|
|
8bf5a98815 | ||
|
|
be566a15a5 | ||
|
|
d5f1b99ac4 | ||
|
|
2144bb0e27 | ||
|
|
bc665bacc7 | ||
|
|
52bfcf4883 | ||
|
|
06df3d6fb6 | ||
|
|
ca719a8697 | ||
|
|
72dfd74005 | ||
|
|
69302c4420 | ||
|
|
42d7019b2e | ||
|
|
5f0d0d6b9b | ||
|
|
76cb63e4f6 | ||
|
|
467d571206 | ||
|
|
972bfa700a | ||
|
|
458955d0fb | ||
|
|
990eeccf45 | ||
|
|
a3a7465f00 | ||
|
|
031a819257 | ||
|
|
eb4b4e3c8c | ||
|
|
d2e1fe9b1d | ||
|
|
6e27a9e39a | ||
|
|
805478c911 | ||
|
|
a281cdeb89 | ||
|
|
cda698a67f | ||
|
|
15acd17716 | ||
|
|
34a2bddfcd | ||
|
|
370f817549 | ||
|
|
041390c37e | ||
|
|
d9fe4bf500 | ||
|
|
e0c7e944fc | ||
|
|
0845fe67db | ||
|
|
fe3b12d900 | ||
|
|
a70d56864e | ||
|
|
fdbb2c5378 | ||
|
|
3c0aaf42af | ||
|
|
438e19160a | ||
|
|
f2b2ff6950 | ||
|
|
86cef96305 | ||
|
|
5f50944baf | ||
|
|
0804fd2353 | ||
|
|
86419eb457 | ||
|
|
76f3ae7bf3 | ||
|
|
aaa85190eb | ||
|
|
e2a4e926b9 | ||
|
|
d6e922dc1c | ||
|
|
27f4317ec6 | ||
|
|
e434348216 | ||
|
|
2e19afedb8 | ||
|
|
da08fa7c63 | ||
|
|
9c96b97dc7 | ||
|
|
28a51b622b | ||
|
|
8bd1da7144 | ||
|
|
e4d0b8ee6e | ||
|
|
1dfb28b362 | ||
|
|
ba618947e7 | ||
|
|
f81041b502 | ||
|
|
f2533a2800 | ||
|
|
bb5b4a7f26 | ||
|
|
20bff87021 | ||
|
|
722b954800 | ||
|
|
19256086c7 | ||
|
|
250fecfcd4 | ||
|
|
cb4d1d5ebb | ||
|
|
d7d557fb2e |
1
.gitattributes
vendored
1
.gitattributes
vendored
@@ -1,3 +1,2 @@
|
||||
# Auto detect text files and perform LF normalization
|
||||
* text=auto
|
||||
*.json filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
160
.gitignore
vendored
Normal file
160
.gitignore
vendored
Normal file
@@ -0,0 +1,160 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
456
README.md
456
README.md
@@ -1,102 +1,148 @@
|
||||
# LLaMA Efficient Tuning
|
||||
# LLaMA Factory: Training and Evaluating Large Language Models with Minimal Effort
|
||||
|
||||
[](https://github.com/hiyouga/LLaMA-Efficient-Tuning/stargazers)
|
||||
[](LICENSE)
|
||||
[](https://github.com/hiyouga/LLaMA-Efficient-Tuning/commits/main)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/stargazers)
|
||||
[](LICENSE)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://github.com/hiyouga/LLaMA-Efficient-Tuning/pulls)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[](https://discord.gg/e73gccsSd)
|
||||
|
||||
👋 Join our [WeChat](assets/wechat.jpg).
|
||||
|
||||
\[ English | [中文](README_zh.md) \]
|
||||
|
||||
## LLaMA Board: A One-stop Web UI for Getting Started with LLaMA Factory
|
||||
|
||||
Launch **LLaMA Board** via `CUDA_VISIBLE_DEVICES=0 python src/train_web.py`. (multiple GPUs are not supported yet)
|
||||
|
||||
Here is an example of altering the self-cognition of an instruction-tuned language model within 10 minutes on a single GPU.
|
||||
|
||||
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1
|
||||
|
||||
## Changelog
|
||||
|
||||
[23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset.
|
||||
[23/10/21] We supported **[NEFTune](https://arxiv.org/abs/2310.05914)** trick for fine-tuning. Try `--neft_alpha` argument to activate NEFTune, e.g., `--neft_alpha 5`.
|
||||
|
||||
[23/07/29] We release two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft)) for details.
|
||||
[23/09/27] We supported **$S^2$-Attn** proposed by [LongLoRA](https://github.com/dvlab-research/LongLoRA) for the LLaMA models. Try `--shift_attn` argument to enable shift short attention.
|
||||
|
||||
[23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--template llama2` argument when you are using the LLaMA-2-chat model.
|
||||
[23/09/23] We integrated MMLU, C-Eval and CMMLU benchmarks in this repo. See [this example](#evaluation) to evaluate your models.
|
||||
|
||||
[23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
|
||||
[23/09/10] We supported using **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)** for the LLaMA models. Try `--flash_attn` argument to enable FlashAttention-2 if you are using RTX4090, A100 or H100 GPUs.
|
||||
|
||||
[23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--template baichuan` argument when you are using the Baichuan-13B-Chat model.
|
||||
[23/08/12] We supported **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
|
||||
|
||||
[23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
|
||||
[23/08/11] We supported **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models.
|
||||
|
||||
[23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--template intern` argument when you are using the InternLM-chat model.
|
||||
[23/07/31] We supported **dataset streaming**. Try `--streaming` and `--max_steps 10000` arguments to load your dataset in streaming mode.
|
||||
|
||||
[23/07/05] Now we support training the **Falcon-7B/40B** models in this repo. Try `--model_name_or_path tiiuae/falcon-7b` and `--lora_target query_key_value` arguments to use the Falcon model.
|
||||
[23/07/29] We released two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft)) for details.
|
||||
|
||||
[23/06/29] We provide a **reproducible example** of training a chat model using instruction-following datasets, see this [Hugging Face Repo](https://huggingface.co/hiyouga/baichuan-7b-sft) for details.
|
||||
[23/07/18] We developed an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
|
||||
|
||||
[23/06/22] Now we align the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**.
|
||||
[23/07/09] We released **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
|
||||
|
||||
[23/06/15] Now we support training the **Baichuan-7B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-7B` and `--lora_target W_pack` arguments to use the Baichuan-7B model.
|
||||
[23/06/29] We provided a **reproducible example** of training a chat model using instruction-following datasets, see [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft) for details.
|
||||
|
||||
[23/06/03] Now we support quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models.
|
||||
[23/06/22] We aligned the [demo API](src/api_demo.py) with the [OpenAI's](https://platform.openai.com/docs/api-reference/chat) format where you can insert the fine-tuned model in **arbitrary ChatGPT-based applications**.
|
||||
|
||||
[23/05/31] Now we support training the **BLOOM & BLOOMZ** models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` and `--lora_target query_key_value` arguments to use the BLOOMZ model.
|
||||
[23/06/03] We supported quantized training and inference (aka **[QLoRA](https://github.com/artidoro/qlora)**). Try `--quantization_bit 4/8` argument to work with quantized models.
|
||||
|
||||
## Supported Models
|
||||
|
||||
- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
|
||||
- [LLaMA-2](https://huggingface.co/meta-llama) (7B/13B/70B)
|
||||
- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
|
||||
- [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
|
||||
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B)
|
||||
- [InternLM](https://github.com/InternLM/InternLM) (7B)
|
||||
| Model | Model size | Default module | Template |
|
||||
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
|
||||
| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
|
||||
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
|
||||
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | - |
|
||||
| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
|
||||
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
|
||||
| [Qwen](https://github.com/QwenLM/Qwen) | 7B/14B | c_attn | qwen |
|
||||
| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
|
||||
> [!NOTE]
|
||||
> **Default module** is used for the `--lora_target` argument, you can use `--lora_target all` to specify all the available modules.
|
||||
>
|
||||
> For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the **corresponding template** for the "chat" models.
|
||||
|
||||
Please refer to [template.py](src/llmtuner/extras/template.py) for a full list of models we supported.
|
||||
|
||||
## Supported Training Approaches
|
||||
|
||||
- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
|
||||
- Full-parameter tuning
|
||||
- Partial-parameter tuning
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685)
|
||||
- [QLoRA](https://arxiv.org/abs/2305.14314)
|
||||
- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
|
||||
- Full-parameter tuning
|
||||
- Partial-parameter tuning
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685)
|
||||
- [QLoRA](https://arxiv.org/abs/2305.14314)
|
||||
- [RLHF](https://arxiv.org/abs/2203.02155)
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685)
|
||||
- [QLoRA](https://arxiv.org/abs/2305.14314)
|
||||
| Approach | Full-parameter | Partial-parameter | LoRA | QLoRA |
|
||||
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
|
||||
| Pre-Training | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| Reward Modeling | | | :white_check_mark: | :white_check_mark: |
|
||||
| PPO Training | | | :white_check_mark: | :white_check_mark: |
|
||||
| DPO Training | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
|
||||
|
||||
> [!NOTE]
|
||||
> Use `--quantization_bit 4/8` argument to enable QLoRA.
|
||||
|
||||
## Provided Datasets
|
||||
|
||||
- For pre-training:
|
||||
- [Wiki Demo (en)](data/wiki_demo.txt)
|
||||
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
|
||||
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
|
||||
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
|
||||
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
||||
- For supervised fine-tuning:
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [Self-cognition (zh)](data/self_cognition.json)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
||||
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
|
||||
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
||||
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
||||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- For reward modelling:
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
<details><summary>Pre-training datasets</summary>
|
||||
|
||||
- [Wiki Demo (en)](data/wiki_demo.txt)
|
||||
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
|
||||
- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
|
||||
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
|
||||
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
||||
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
|
||||
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
|
||||
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
|
||||
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>Supervised fine-tuning datasets</summary>
|
||||
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Self-cognition (zh)](data/self_cognition.json)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
||||
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
|
||||
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
||||
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
|
||||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
|
||||
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
||||
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
||||
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>Preference datasets</summary>
|
||||
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
|
||||
</details>
|
||||
|
||||
Please refer to [data/README.md](data/README.md) for details.
|
||||
|
||||
@@ -111,9 +157,10 @@ huggingface-cli login
|
||||
|
||||
- Python 3.8+ and PyTorch 1.13.1+
|
||||
- 🤗Transformers, Datasets, Accelerate, PEFT and TRL
|
||||
- jieba, rouge-chinese and nltk (used at evaluation)
|
||||
- gradio and matplotlib (used in web_demo.py)
|
||||
- uvicorn, fastapi and sse-starlette (used in api_demo.py)
|
||||
- sentencepiece, protobuf and tiktoken
|
||||
- fire, jieba, rouge-chinese and nltk (used at evaluation and predict)
|
||||
- gradio and matplotlib (used in web UI)
|
||||
- uvicorn, fastapi and sse-starlette (used in API)
|
||||
|
||||
And **powerful GPUs**!
|
||||
|
||||
@@ -121,18 +168,18 @@ And **powerful GPUs**!
|
||||
|
||||
### Data Preparation (optional)
|
||||
|
||||
Please refer to `data/example_dataset` for checking the details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset.
|
||||
Please refer to [data/README.md](data/README.md) for checking the details about the format of dataset files. You can either use a single `.json` file or a [dataset loading script](https://huggingface.co/docs/datasets/dataset_script) with multiple files to create a custom dataset.
|
||||
|
||||
Note: please update `data/dataset_info.json` to use your custom dataset. About the format of this file, please refer to `data/README.md`.
|
||||
> [!NOTE]
|
||||
> Please update `data/dataset_info.json` to use your custom dataset. About the format of this file, please refer to `data/README.md`.
|
||||
|
||||
### Dependence Installation (optional)
|
||||
|
||||
```bash
|
||||
git lfs install
|
||||
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
|
||||
conda create -n llama_etuning python=3.10
|
||||
conda activate llama_etuning
|
||||
cd LLaMA-Efficient-Tuning
|
||||
git clone https://github.com/hiyouga/LLaMA-Factory.git
|
||||
conda create -n llama_factory python=3.10
|
||||
conda activate llama_factory
|
||||
cd LLaMA-Factory
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
@@ -142,24 +189,21 @@ If you want to enable the quantized LoRA (QLoRA) on the Windows platform, you wi
|
||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
|
||||
```
|
||||
|
||||
### All-in-one Web UI
|
||||
### Train on a single GPU
|
||||
|
||||
```bash
|
||||
python src/train_web.py
|
||||
```
|
||||
> [!IMPORTANT]
|
||||
> If you want to train models on multiple GPUs, please refer to [Distributed Training](#distributed-training).
|
||||
|
||||
Currently the web UI only supports training on a single GPU.
|
||||
|
||||
### (Continually) Pre-Training
|
||||
#### Pre-Training
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage pt \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--dataset wiki_demo \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--output_dir path_to_pt_checkpoint \
|
||||
--overwrite_cache \
|
||||
--per_device_train_batch_size 4 \
|
||||
@@ -173,16 +217,17 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--fp16
|
||||
```
|
||||
|
||||
### Supervised Fine-Tuning
|
||||
#### Supervised Fine-Tuning
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--dataset alpaca_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--output_dir path_to_sft_checkpoint \
|
||||
--overwrite_cache \
|
||||
--per_device_train_batch_size 4 \
|
||||
@@ -196,40 +241,42 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--fp16
|
||||
```
|
||||
|
||||
### Reward Model Training
|
||||
#### Reward Modeling
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage rm \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--dataset comparison_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--output_dir path_to_rm_checkpoint \
|
||||
--per_device_train_batch_size 4 \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 10 \
|
||||
--save_steps 1000 \
|
||||
--learning_rate 1e-5 \
|
||||
--learning_rate 1e-6 \
|
||||
--num_train_epochs 1.0 \
|
||||
--plot_loss \
|
||||
--fp16
|
||||
```
|
||||
|
||||
### PPO Training (RLHF)
|
||||
#### PPO Training
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage ppo \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--dataset alpaca_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--reward_model path_to_rm_checkpoint \
|
||||
@@ -241,29 +288,51 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--save_steps 1000 \
|
||||
--learning_rate 1e-5 \
|
||||
--num_train_epochs 1.0 \
|
||||
--plot_loss
|
||||
--plot_loss \
|
||||
--fp16
|
||||
```
|
||||
|
||||
#### DPO Training
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage dpo \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--dataset comparison_gpt4_en \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--output_dir path_to_dpo_checkpoint \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 10 \
|
||||
--save_steps 1000 \
|
||||
--learning_rate 1e-5 \
|
||||
--num_train_epochs 1.0 \
|
||||
--plot_loss \
|
||||
--fp16
|
||||
```
|
||||
|
||||
### Distributed Training
|
||||
|
||||
#### Use Huggingface Accelerate
|
||||
|
||||
```bash
|
||||
accelerate config # configure the environment
|
||||
accelerate launch src/train_bash.py # arguments (same as above)
|
||||
```
|
||||
|
||||
<details><summary>Example configuration for full-tuning with DeepSpeed ZeRO-2</summary>
|
||||
<details><summary>Example config for LoRA training</summary>
|
||||
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_clipping: 0.5
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
distributed_type: MULTI_GPU
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: fp16
|
||||
@@ -279,31 +348,110 @@ use_cpu: false
|
||||
|
||||
</details>
|
||||
|
||||
### Evaluation (BLEU and ROUGE_CHINESE)
|
||||
#### Use DeepSpeed
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_eval \
|
||||
--dataset alpaca_gpt4_en \
|
||||
deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
|
||||
--deepspeed ds_config.json \
|
||||
... # arguments (same as above)
|
||||
```
|
||||
|
||||
<details><summary>Example config for full-parameter training with DeepSpeed ZeRO-2</summary>
|
||||
|
||||
```json
|
||||
{
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"zero_allow_untested_optimizer": true,
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 16,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"overlap_comm": false,
|
||||
"contiguous_gradients": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### Export model
|
||||
|
||||
```bash
|
||||
python src/export_model.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_eval_result \
|
||||
--per_device_eval_batch_size 8 \
|
||||
--max_samples 100 \
|
||||
--predict_with_generate
|
||||
--export_dir path_to_export
|
||||
```
|
||||
|
||||
We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit evaluation.
|
||||
### API Demo
|
||||
|
||||
```bash
|
||||
python src/api_demo.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> Visit `http://localhost:8000/docs` for API documentation.
|
||||
|
||||
### CLI Demo
|
||||
|
||||
```bash
|
||||
python src/cli_demo.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
### Web Demo
|
||||
|
||||
```bash
|
||||
python src/web_demo.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--template vanilla \
|
||||
--task mmlu \
|
||||
--split test \
|
||||
--lang en \
|
||||
--n_shot 5 \
|
||||
--batch_size 4
|
||||
```
|
||||
|
||||
### Predict
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_predict \
|
||||
--dataset alpaca_gpt4_en \
|
||||
--template default \
|
||||
@@ -315,85 +463,39 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--predict_with_generate
|
||||
```
|
||||
|
||||
### API Demo
|
||||
> [!NOTE]
|
||||
> We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` at 4/8-bit predict.
|
||||
|
||||
```bash
|
||||
python src/api_demo.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
## Projects using LLaMA Factory
|
||||
|
||||
Visit `http://localhost:8000/docs` for API documentation.
|
||||
|
||||
### CLI Demo
|
||||
|
||||
```bash
|
||||
python src/cli_demo.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
### Web Demo
|
||||
|
||||
```bash
|
||||
python src/web_demo.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
### Export model
|
||||
|
||||
```bash
|
||||
python src/export_model.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_export
|
||||
```
|
||||
|
||||
## TODO
|
||||
|
||||
- [ ] Supporting flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention)).
|
||||
- [ ] Implementing multi-query attention for faster inference.
|
||||
- [ ] Supporting full-parameter RLHF training.
|
||||
- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: A large language model for Astronomy, based on ChatGLM2-6B and Qwen-14B.
|
||||
- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: A large language model specialized in Chinese legal domain, based on Baichuan-13B, is capable of retrieving and reasoning on legal knowledge.
|
||||
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: A large language model specialized in Chinese medical domain, based on Baichuan-7B and ChatGLM-6B.
|
||||
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: A series of large language models for Chinese medical domain, based on LLaMA2-7B and Baichuan-13B.
|
||||
|
||||
## License
|
||||
|
||||
This repository is licensed under the [Apache-2.0 License](LICENSE).
|
||||
|
||||
Please follow the model licenses to use the corresponding model weights:
|
||||
|
||||
- [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
|
||||
- [LLaMA-2](https://ai.meta.com/llama/license/)
|
||||
- [BLOOM](https://huggingface.co/spaces/bigscience/license)
|
||||
- [Falcon](LICENSE)
|
||||
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
|
||||
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
|
||||
Please follow the model licenses to use the corresponding model weights: [Baichuan](https://huggingface.co/baichuan-inc/Baichuan-13B-Base/resolve/main/Community%20License%20for%20Baichuan-13B%20Model.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/resolve/main/Community%20License%20for%20Baichuan2%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
|
||||
|
||||
## Citation
|
||||
|
||||
If this work is helpful, please kindly cite as:
|
||||
|
||||
```bibtex
|
||||
@Misc{llama-efficient-tuning,
|
||||
title = {LLaMA Efficient Tuning},
|
||||
@Misc{llama-factory,
|
||||
title = {LLaMA Factory},
|
||||
author = {hiyouga},
|
||||
howpublished = {\url{https://github.com/hiyouga/LLaMA-Efficient-Tuning}},
|
||||
howpublished = {\url{https://github.com/hiyouga/LLaMA-Factory}},
|
||||
year = {2023}
|
||||
}
|
||||
```
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
This repo is a sibling of [ChatGLM-Efficient-Tuning](https://github.com/hiyouga/ChatGLM-Efficient-Tuning). They share a similar code structure of efficient tuning on large language models.
|
||||
This repo benefits from [PEFT](https://github.com/huggingface/peft), [QLoRA](https://github.com/artidoro/qlora) and [FastChat](https://github.com/lm-sys/FastChat). Thanks for their wonderful works.
|
||||
|
||||
## Star History
|
||||
|
||||

|
||||

|
||||
|
||||
453
README_zh.md
453
README_zh.md
@@ -1,104 +1,150 @@
|
||||
# LLaMA Efficient Tuning
|
||||
# LLaMA Factory: 轻松的大模型训练与评估
|
||||
|
||||
[](https://github.com/hiyouga/LLaMA-Efficient-Tuning/stargazers)
|
||||
[](LICENSE)
|
||||
[](https://github.com/hiyouga/LLaMA-Efficient-Tuning/commits/main)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/stargazers)
|
||||
[](LICENSE)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/commits/main)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://github.com/hiyouga/LLaMA-Efficient-Tuning/pulls)
|
||||
[](https://pypi.org/project/llmtuner/)
|
||||
[](https://github.com/hiyouga/LLaMA-Factory/pulls)
|
||||
[](https://discord.gg/e73gccsSd)
|
||||
|
||||
👋 加入我们的[微信群](assets/wechat.jpg)。
|
||||
|
||||
\[ [English](README.md) | 中文 \]
|
||||
|
||||
## LLaMA Board: 通过一站式网页界面快速上手 LLaMA Factory
|
||||
|
||||
使用 `CUDA_VISIBLE_DEVICES=0 python src/train_web.py` 启动 **LLaMA Board**。(该界面目前仅支持单卡训练)
|
||||
|
||||
下面是使用单张 GPU 在 10 分钟内更改对话式大型语言模型自我认知的示例。
|
||||
|
||||
https://github.com/hiyouga/LLaMA-Factory/assets/16256802/6ba60acc-e2e2-4bec-b846-2d88920d5ba1
|
||||
|
||||
## 更新日志
|
||||
|
||||
[23/07/31] 现在我们支持了训练数据流式加载。请尝试使用 `--streaming` 和 `--max_steps 100` 参数来流式加载数据集。
|
||||
[23/10/21] 我们支持了 **[NEFTune](https://arxiv.org/abs/2310.05914)** 训练技巧。请使用 `--neft_alpha` 参数启用 NEFTune,例如 `--neft_alpha 5`。
|
||||
|
||||
[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft))。
|
||||
[23/09/27] 我们针对 LLaMA 模型支持了 [LongLoRA](https://github.com/dvlab-research/LongLoRA) 提出的 **$S^2$-Attn**。请使用 `--shift_attn` 参数以启用该功能。
|
||||
|
||||
[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。请注意使用 LLaMA-2-chat 模型需要添加 `--template llama2` 参数。
|
||||
[23/09/23] 我们在项目中集成了 MMLU、C-Eval 和 CMMLU 评估集。使用方法请参阅[此示例](#模型评估)。
|
||||
|
||||
[23/07/18] 我们开发了支持训练和测试的浏览器一键微调界面。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
|
||||
[23/09/10] 我们针对 LLaMA 模型支持了 **[FlashAttention-2](https://github.com/Dao-AILab/flash-attention)**。如果您使用的是 RTX4090、A100 或 H100 GPU,请使用 `--flash_attn` 参数以启用 FlashAttention-2。
|
||||
|
||||
[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path path_to_baichuan_model` 和 `--lora_target W_pack` 参数。请注意使用 Baichuan-13B-Chat 模型需要添加 `--template baichuan` 参数。
|
||||
[23/08/12] 我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
|
||||
|
||||
[23/07/09] 我们开源了 [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。
|
||||
[23/08/11] 我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。使用方法请参阅[此示例](#dpo-训练)。
|
||||
|
||||
[23/07/07] 现在我们支持了 **InternLM-7B** 模型的训练。请尝试使用 `--model_name_or_path internlm/internlm-7b` 参数。请注意使用 InternLM-chat 模型需要添加 `--template intern` 参数。
|
||||
[23/07/31] 我们支持了**数据流式加载**。请尝试使用 `--streaming` 和 `--max_steps 10000` 参数来流式加载数据集。
|
||||
|
||||
[23/07/05] 现在我们支持了 **Falcon-7B/40B** 模型的训练。请尝试使用 `--model_name_or_path tiiuae/falcon-7b` 和 `--lora_target query_key_value` 参数。
|
||||
[23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/Baichuan-13B-sft))。
|
||||
|
||||
[23/06/29] 我们提供了一个**可复现的**指令模型微调示例,详细内容请查阅 [Hugging Face 项目](https://huggingface.co/hiyouga/baichuan-7b-sft)。
|
||||
[23/07/18] 我们开发了支持训练和测试的**浏览器一体化界面**。请使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
|
||||
|
||||
[23/06/22] 我们对齐了[示例 API](src/api_demo.py) 与 [OpenAI API](https://platform.openai.com/docs/api-reference/chat) 的格式,您可以将微调模型接入任意基于 ChatGPT 的应用中。
|
||||
[23/07/09] 我们开源了 **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。
|
||||
|
||||
[23/06/15] 现在我们支持了 **Baichuan-7B** 模型的训练。请尝试使用 `--model_name_or_path baichuan-inc/Baichuan-7B` 和 `--lora_target W_pack` 参数。
|
||||
[23/06/29] 我们提供了一个**可复现的**指令模型微调示例,详细内容请查阅 [Baichuan-7B-sft](https://huggingface.co/hiyouga/Baichuan-7B-sft)。
|
||||
|
||||
[23/06/03] 现在我们实现了 4 比特的 LoRA 训练(也称 [QLoRA](https://github.com/artidoro/qlora))。请尝试使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
|
||||
[23/06/22] 我们对齐了[示例 API](src/api_demo.py) 与 [OpenAI API](https://platform.openai.com/docs/api-reference/chat) 的格式,您可以将微调模型接入**任意基于 ChatGPT 的应用**中。
|
||||
|
||||
[23/05/31] 现在我们支持了 **BLOOM & BLOOMZ** 模型的训练。请尝试使用 `--model_name_or_path bigscience/bloomz-7b1-mt` 和 `--lora_target query_key_value` 参数。
|
||||
[23/06/03] 我们实现了 4 比特的 LoRA 训练(也称 **[QLoRA](https://github.com/artidoro/qlora)**)。请使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
|
||||
|
||||
## 模型
|
||||
|
||||
- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
|
||||
- [LLaMA-2](https://huggingface.co/meta-llama) (7B/13B/70B)
|
||||
- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
|
||||
- [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
|
||||
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B)
|
||||
- [InternLM](https://github.com/InternLM/InternLM) (7B)
|
||||
| 模型名 | 模型大小 | 默认模块 | Template |
|
||||
| -------------------------------------------------------- | --------------------------- | ----------------- | --------- |
|
||||
| [Baichuan](https://github.com/baichuan-inc/Baichuan-13B) | 7B/13B | W_pack | baichuan |
|
||||
| [Baichuan2](https://github.com/baichuan-inc/Baichuan2) | 7B/13B | W_pack | baichuan2 |
|
||||
| [BLOOM](https://huggingface.co/bigscience/bloom) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [BLOOMZ](https://huggingface.co/bigscience/bloomz) | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | - |
|
||||
| [ChatGLM3](https://github.com/THUDM/ChatGLM3) | 6B | query_key_value | chatglm3 |
|
||||
| [Falcon](https://huggingface.co/tiiuae/falcon-7b) | 7B/40B/180B | query_key_value | - |
|
||||
| [InternLM](https://github.com/InternLM/InternLM) | 7B/20B | q_proj,v_proj | intern |
|
||||
| [LLaMA](https://github.com/facebookresearch/llama) | 7B/13B/33B/65B | q_proj,v_proj | - |
|
||||
| [LLaMA-2](https://huggingface.co/meta-llama) | 7B/13B/70B | q_proj,v_proj | llama2 |
|
||||
| [Mistral](https://huggingface.co/mistralai) | 7B | q_proj,v_proj | mistral |
|
||||
| [Phi-1.5](https://huggingface.co/microsoft/phi-1_5) | 1.3B | Wqkv | - |
|
||||
| [Qwen](https://github.com/QwenLM/Qwen) | 7B/14B | c_attn | qwen |
|
||||
| [XVERSE](https://github.com/xverse-ai) | 7B/13B/65B | q_proj,v_proj | xverse |
|
||||
|
||||
## 微调方法
|
||||
> [!NOTE]
|
||||
> **默认模块**应作为 `--lora_target` 参数的默认值,可使用 `--lora_target all` 参数指定全部模块。
|
||||
>
|
||||
> 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Chat)模型请务必使用**对应的模板**。
|
||||
|
||||
- [二次预训练](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
|
||||
- 全参数微调
|
||||
- 部分参数微调
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685)
|
||||
- [QLoRA](https://arxiv.org/abs/2305.14314)
|
||||
- [指令监督微调](https://arxiv.org/abs/2109.01652)
|
||||
- 全参数微调
|
||||
- 部分参数微调
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685)
|
||||
- [QLoRA](https://arxiv.org/abs/2305.14314)
|
||||
- [人类反馈的强化学习(RLHF)](https://arxiv.org/abs/2203.02155)
|
||||
- [LoRA](https://arxiv.org/abs/2106.09685)
|
||||
- [QLoRA](https://arxiv.org/abs/2305.14314)
|
||||
项目所支持模型的完整列表请参阅 [template.py](src/llmtuner/extras/template.py)。
|
||||
|
||||
## 训练方法
|
||||
|
||||
| 方法 | 全参数训练 | 部分参数训练 | LoRA | QLoRA |
|
||||
| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
|
||||
| 预训练 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
|
||||
| 奖励模型训练 | | | :white_check_mark: | :white_check_mark: |
|
||||
| PPO 训练 | | | :white_check_mark: | :white_check_mark: |
|
||||
| DPO 训练 | :white_check_mark: | | :white_check_mark: | :white_check_mark: |
|
||||
|
||||
> [!NOTE]
|
||||
> 请使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。
|
||||
|
||||
## 数据集
|
||||
|
||||
- 用于二次预训练:
|
||||
- [Wiki Demo (en)](data/wiki_demo.txt)
|
||||
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
|
||||
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
|
||||
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
|
||||
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
||||
- 用于指令监督微调:
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [Self-cognition (zh)](data/self_cognition.json)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
||||
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
|
||||
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
||||
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
||||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- 用于奖励模型训练:
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
<details><summary>预训练数据集</summary>
|
||||
|
||||
使用方法请参考 [data/README.md](data/README_zh.md) 文件。
|
||||
- [Wiki Demo (en)](data/wiki_demo.txt)
|
||||
- [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
|
||||
- [RedPajama V2 (en)](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-V2)
|
||||
- [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
|
||||
- [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
|
||||
- [Pile (en)](https://huggingface.co/datasets/EleutherAI/pile)
|
||||
- [SkyPile (zh)](https://huggingface.co/datasets/Skywork/SkyPile-150B)
|
||||
- [The Stack (en)](https://huggingface.co/datasets/bigcode/the-stack)
|
||||
- [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>指令微调数据集</summary>
|
||||
|
||||
- [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
|
||||
- [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
- [Self-cognition (zh)](data/self_cognition.json)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
|
||||
- [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
|
||||
- [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
|
||||
- [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
|
||||
- [BELLE 0.5M (zh)](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
|
||||
- [BELLE Dialogue 0.4M (zh)](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
|
||||
- [BELLE School Math 0.25M (zh)](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
|
||||
- [BELLE Multiturn Chat 0.8M (zh)](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
|
||||
- [UltraChat (en)](https://github.com/thunlp/UltraChat)
|
||||
- [LIMA (en)](https://huggingface.co/datasets/GAIR/lima)
|
||||
- [OpenPlatypus (en)](https://huggingface.co/datasets/garage-bAInd/Open-Platypus)
|
||||
- [CodeAlpaca 20k (en)](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
|
||||
- [Alpaca CoT (multilingual)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
|
||||
- [MathInstruct (en)](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)
|
||||
- [Firefly 1.1M (zh)](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
|
||||
- [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
|
||||
- [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
|
||||
- [Ad Gen (zh)](https://huggingface.co/datasets/HasturOfficial/adgen)
|
||||
- [ShareGPT Hyperfiltered (en)](https://huggingface.co/datasets/totally-not-an-llm/sharegpt-hyperfiltered-3k)
|
||||
- [ShareGPT4 (en&zh)](https://huggingface.co/datasets/shibing624/sharegpt_gpt4)
|
||||
- [UltraChat 200k (en)](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)
|
||||
- [AgentInstruct (en)](https://huggingface.co/datasets/THUDM/AgentInstruct)
|
||||
- [LMSYS Chat 1M (en)](https://huggingface.co/datasets/lmsys/lmsys-chat-1m)
|
||||
- [Evol Instruct V2 (en)](https://huggingface.co/datasets/WizardLM/WizardLM_evol_instruct_V2_196k)
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>偏好数据集</summary>
|
||||
|
||||
- [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
|
||||
- [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
|
||||
|
||||
</details>
|
||||
|
||||
使用方法请参考 [data/README_zh.md](data/README_zh.md) 文件。
|
||||
|
||||
部分数据集的使用需要确认,我们推荐使用下述命令登录您的 Hugging Face 账户。
|
||||
|
||||
@@ -111,7 +157,8 @@ huggingface-cli login
|
||||
|
||||
- Python 3.8+ 和 PyTorch 1.13.1+
|
||||
- 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
|
||||
- jieba, rouge-chinese 和 nltk (用于评估)
|
||||
- sentencepiece, protobuf 和 tiktoken
|
||||
- fire, jieba, rouge-chinese 和 nltk (用于评估及预测)
|
||||
- gradio 和 matplotlib (用于网页端交互)
|
||||
- uvicorn, fastapi 和 sse-starlette (用于 API)
|
||||
|
||||
@@ -121,18 +168,18 @@ huggingface-cli login
|
||||
|
||||
### 数据准备(可跳过)
|
||||
|
||||
关于数据集文件的格式,请参考 `data/example_dataset` 文件夹的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。
|
||||
关于数据集文件的格式,请参考 [data/README_zh.md](data/README_zh.md) 的内容。构建自定义数据集时,既可以使用单个 `.json` 文件,也可以使用一个[数据加载脚本](https://huggingface.co/docs/datasets/dataset_script)和多个文件。
|
||||
|
||||
注意:使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README.md`。
|
||||
> [!NOTE]
|
||||
> 使用自定义数据集时,请更新 `data/dataset_info.json` 文件,该文件的格式请参考 `data/README_zh.md`。
|
||||
|
||||
### 环境搭建(可跳过)
|
||||
|
||||
```bash
|
||||
git lfs install
|
||||
git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
|
||||
conda create -n llama_etuning python=3.10
|
||||
conda activate llama_etuning
|
||||
cd LLaMA-Efficient-Tuning
|
||||
git clone https://github.com/hiyouga/LLaMA-Factory.git
|
||||
conda create -n llama_factory python=3.10
|
||||
conda activate llama_factory
|
||||
cd LLaMA-Factory
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
@@ -142,24 +189,21 @@ pip install -r requirements.txt
|
||||
pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
|
||||
```
|
||||
|
||||
### 浏览器一键微调/测试
|
||||
### 单 GPU 训练
|
||||
|
||||
```bash
|
||||
python src/train_web.py
|
||||
```
|
||||
> [!IMPORTANT]
|
||||
> 如果您使用多张 GPU 训练模型,请移步[多 GPU 分布式训练](#多-gpu-分布式训练)部分。
|
||||
|
||||
目前网页 UI 仅支持单卡训练。
|
||||
|
||||
### 二次预训练
|
||||
#### 预训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage pt \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--dataset wiki_demo \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--output_dir path_to_pt_checkpoint \
|
||||
--overwrite_cache \
|
||||
--per_device_train_batch_size 4 \
|
||||
@@ -173,16 +217,17 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--fp16
|
||||
```
|
||||
|
||||
### 指令监督微调
|
||||
#### 指令监督微调
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--dataset alpaca_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--output_dir path_to_sft_checkpoint \
|
||||
--overwrite_cache \
|
||||
--per_device_train_batch_size 4 \
|
||||
@@ -196,40 +241,42 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--fp16
|
||||
```
|
||||
|
||||
### 奖励模型训练
|
||||
#### 奖励模型训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage rm \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--dataset comparison_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--output_dir path_to_rm_checkpoint \
|
||||
--per_device_train_batch_size 4 \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 10 \
|
||||
--save_steps 1000 \
|
||||
--learning_rate 1e-5 \
|
||||
--learning_rate 1e-6 \
|
||||
--num_train_epochs 1.0 \
|
||||
--plot_loss \
|
||||
--fp16
|
||||
```
|
||||
|
||||
### RLHF 训练
|
||||
#### PPO 训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage ppo \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--dataset alpaca_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--reward_model path_to_rm_checkpoint \
|
||||
@@ -244,26 +291,47 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--plot_loss
|
||||
```
|
||||
|
||||
#### DPO 训练
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage dpo \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_train \
|
||||
--dataset comparison_gpt4_zh \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--lora_target q_proj,v_proj \
|
||||
--resume_lora_training False \
|
||||
--checkpoint_dir path_to_sft_checkpoint \
|
||||
--output_dir path_to_dpo_checkpoint \
|
||||
--per_device_train_batch_size 2 \
|
||||
--gradient_accumulation_steps 4 \
|
||||
--lr_scheduler_type cosine \
|
||||
--logging_steps 10 \
|
||||
--save_steps 1000 \
|
||||
--learning_rate 1e-5 \
|
||||
--num_train_epochs 1.0 \
|
||||
--plot_loss \
|
||||
--fp16
|
||||
```
|
||||
|
||||
### 多 GPU 分布式训练
|
||||
|
||||
#### 使用 Huggingface Accelerate
|
||||
|
||||
```bash
|
||||
accelerate config # 首先配置分布式环境
|
||||
accelerate launch src/train_bash.py # 参数同上
|
||||
```
|
||||
|
||||
<details><summary>使用 DeepSpeed ZeRO-2 进行全参数微调的 Accelerate 配置示例</summary>
|
||||
<details><summary>LoRA 训练的 Accelerate 配置示例</summary>
|
||||
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 4
|
||||
gradient_clipping: 0.5
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
distributed_type: MULTI_GPU
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: fp16
|
||||
@@ -279,31 +347,110 @@ use_cpu: false
|
||||
|
||||
</details>
|
||||
|
||||
### 指标评估(BLEU分数和汉语ROUGE分数)
|
||||
#### 使用 DeepSpeed
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--do_eval \
|
||||
--dataset alpaca_gpt4_zh \
|
||||
deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
|
||||
--deepspeed ds_config.json \
|
||||
... # 参数同上
|
||||
```
|
||||
|
||||
<details><summary>使用 DeepSpeed ZeRO-2 进行全参数训练的 DeepSpeed 配置示例</summary>
|
||||
|
||||
```json
|
||||
{
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"gradient_clipping": "auto",
|
||||
"zero_allow_untested_optimizer": true,
|
||||
"fp16": {
|
||||
"enabled": "auto",
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 16,
|
||||
"loss_scale_window": 1000,
|
||||
"hysteresis": 2,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"overlap_comm": false,
|
||||
"contiguous_gradients": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### 导出微调后的完整模型
|
||||
|
||||
```bash
|
||||
python src/export_model.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_eval_result \
|
||||
--per_device_eval_batch_size 8 \
|
||||
--max_samples 100 \
|
||||
--predict_with_generate
|
||||
--export_dir path_to_export
|
||||
```
|
||||
|
||||
我们建议在量化模型的评估中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128` 参数。
|
||||
### API 服务
|
||||
|
||||
```bash
|
||||
python src/api_demo.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> 关于 API 文档请见 `http://localhost:8000/docs`。
|
||||
|
||||
### 命令行测试
|
||||
|
||||
```bash
|
||||
python src/cli_demo.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
### 浏览器测试
|
||||
|
||||
```bash
|
||||
python src/web_demo.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
### 模型评估
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/evaluate.py \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--template vanilla \
|
||||
--task ceval \
|
||||
--split validation \
|
||||
--lang zh \
|
||||
--n_shot 5 \
|
||||
--batch_size 4
|
||||
```
|
||||
|
||||
### 模型预测
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--stage sft \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--model_name_or_path path_to_llama_model \
|
||||
--do_predict \
|
||||
--dataset alpaca_gpt4_zh \
|
||||
--template default \
|
||||
@@ -315,85 +462,39 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
|
||||
--predict_with_generate
|
||||
```
|
||||
|
||||
### API 服务
|
||||
> [!NOTE]
|
||||
> 我们建议在量化模型的预测中使用 `--per_device_eval_batch_size=1` 和 `--max_target_length 128`。
|
||||
|
||||
```bash
|
||||
python src/api_demo.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
## 使用了 LLaMA Factory 的项目
|
||||
|
||||
关于 API 文档请见 `http://localhost:8000/docs`。
|
||||
|
||||
### 命令行测试
|
||||
|
||||
```bash
|
||||
python src/cli_demo.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
### 浏览器测试
|
||||
|
||||
```bash
|
||||
python src/web_demo.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint
|
||||
```
|
||||
|
||||
### 导出微调模型
|
||||
|
||||
```bash
|
||||
python src/export_model.py \
|
||||
--model_name_or_path path_to_your_model \
|
||||
--template default \
|
||||
--finetuning_type lora \
|
||||
--checkpoint_dir path_to_checkpoint \
|
||||
--output_dir path_to_export
|
||||
```
|
||||
|
||||
## TODO
|
||||
|
||||
- [ ] 实现 flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention))。
|
||||
- [ ] 在推理阶段使用 Multi-query attention 进行加速。
|
||||
- [ ] 支持 RLHF 的全参数微调。
|
||||
- **[StarWhisper](https://github.com/Yu-Yang-Li/StarWhisper)**: 天文大模型 StarWhisper,基于 ChatGLM2-6B 和 Qwen-14B 在天文数据上微调而得。
|
||||
- **[DISC-LawLLM](https://github.com/FudanDISC/DISC-LawLLM)**: 中文法律领域大模型 DISC-LawLLM,基于 Baichuan-13B 微调而得,具有法律推理和知识检索能力。
|
||||
- **[Sunsimiao](https://github.com/thomas-yanxin/Sunsimiao)**: 孙思邈中文医疗大模型 Sumsimiao,基于 Baichuan-7B 和 ChatGLM-6B 在中文医疗数据上微调而得。
|
||||
- **[CareGPT](https://github.com/WangRongsheng/CareGPT)**: 医疗大模型项目 CareGPT,基于 LLaMA2-7B 和 Baichuan-13B 在中文医疗数据上微调而得。
|
||||
|
||||
## 协议
|
||||
|
||||
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源。
|
||||
|
||||
使用模型权重时,请遵循对应的模型协议:
|
||||
|
||||
- [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md)
|
||||
- [LLaMA-2](https://ai.meta.com/llama/license/)
|
||||
- [BLOOM](https://huggingface.co/spaces/bigscience/license)
|
||||
- [Falcon](LICENSE)
|
||||
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
|
||||
- [InternLM](https://github.com/InternLM/InternLM#open-source-license)
|
||||
使用模型权重时,请遵循对应的模型协议:[Baichuan](https://huggingface.co/baichuan-inc/Baichuan-13B-Base/resolve/main/Community%20License%20for%20Baichuan-13B%20Model.pdf) / [Baichuan2](https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/resolve/main/Community%20License%20for%20Baichuan2%20Model.pdf) / [BLOOM](https://huggingface.co/spaces/bigscience/license) / [ChatGLM3](https://github.com/THUDM/ChatGLM3/blob/main/MODEL_LICENSE) / [Falcon](https://huggingface.co/tiiuae/falcon-180B/blob/main/LICENSE.txt) / [InternLM](https://github.com/InternLM/InternLM#license) / [LLaMA](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) / [LLaMA-2](https://ai.meta.com/llama/license/) / [Mistral](LICENSE) / [Phi-1.5](https://huggingface.co/microsoft/phi-1_5/resolve/main/Research%20License.docx) / [Qwen](https://github.com/QwenLM/Qwen/blob/main/LICENSE) / [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
|
||||
|
||||
## 引用
|
||||
|
||||
如果您觉得此项目有帮助,请考虑以下列格式引用
|
||||
|
||||
```bibtex
|
||||
@Misc{llama-efficient-tuning,
|
||||
title = {LLaMA Efficient Tuning},
|
||||
@Misc{llama-factory,
|
||||
title = {LLaMA Factory},
|
||||
author = {hiyouga},
|
||||
howpublished = {\url{https://github.com/hiyouga/LLaMA-Efficient-Tuning}},
|
||||
howpublished = {\url{https://github.com/hiyouga/LLaMA-Factory}},
|
||||
year = {2023}
|
||||
}
|
||||
```
|
||||
|
||||
## 致谢
|
||||
|
||||
本项目是 [ChatGLM-Efficient-Tuning](https://github.com/hiyouga/ChatGLM-Efficient-Tuning) 的同类项目。采用了类似的代码结构和训练方法。
|
||||
本项目受益于 [PEFT](https://github.com/huggingface/peft)、[QLoRA](https://github.com/artidoro/qlora) 和 [FastChat](https://github.com/lm-sys/FastChat),感谢以上诸位作者的付出。
|
||||
|
||||
## Star History
|
||||
|
||||

|
||||

|
||||
|
||||
103
data/README.md
103
data/README.md
@@ -2,17 +2,106 @@ If you are using a custom dataset, please provide your dataset definition in the
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"hf_hub_url": "the name of the dataset repository on the HuggingFace hub. (if specified, ignore below 3 arguments)",
|
||||
"hf_hub_url": "the name of the dataset repository on the Hugging Face hub. (if specified, ignore below 3 arguments)",
|
||||
"script_url": "the name of the directory containing a dataset loading script. (if specified, ignore below 2 arguments)",
|
||||
"file_name": "the name of the dataset file in the this directory. (required if above are not specified)",
|
||||
"file_sha1": "the SHA-1 hash value of the dataset file. (optional)",
|
||||
"file_sha1": "the SHA-1 hash value of the dataset file. (optional, does not affect training)",
|
||||
"subset": "the name of the subset. (optional, default: None)",
|
||||
"ranking": "whether the dataset is a preference dataset or not. (default: false)",
|
||||
"formatting": "the format of the dataset. (optional, default: alpaca, can be chosen from {alpaca, sharegpt})",
|
||||
"columns": {
|
||||
"prompt": "the name of the column in the datasets containing the prompts. (default: instruction)",
|
||||
"query": "the name of the column in the datasets containing the queries. (default: input)",
|
||||
"response": "the name of the column in the datasets containing the responses. (default: output)",
|
||||
"history": "the name of the column in the datasets containing the history of chat. (default: None)"
|
||||
"prompt": "the column name in the dataset containing the prompts. (default: instruction, for alpaca)",
|
||||
"query": "the column name in the dataset containing the queries. (default: input, for alpaca)",
|
||||
"response": "the column name in the dataset containing the responses. (default: output, for alpaca)",
|
||||
"history": "the column name in the dataset containing the histories. (default: None, for alpaca)",
|
||||
"messages": "the column name in the dataset containing the messages. (default: conversations, for sharegpt)",
|
||||
"role": "the key in the message represents the identity. (default: from, for sharegpt)",
|
||||
"content": "the key in the message represents the content. (default: value, for sharegpt)"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
where the `prompt` and `response` columns should contain non-empty values. The `query` column will be concatenated with the `prompt` column and used as input for the model. The `history` column should contain a list where each element is a string tuple representing a query-response pair.
|
||||
Given above, you can use the custom dataset via specifying `--dataset dataset_name`.
|
||||
|
||||
Currently we support dataset in **alpaca** or **sharegpt** format, the dataset in alpaca format should follow the below format:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "user instruction (required)",
|
||||
"input": "user input (optional)",
|
||||
"output": "model response (required)",
|
||||
"history": [
|
||||
["user instruction in the first round (optional)", "model response in the first round (optional)"],
|
||||
["user instruction in the second round (optional)", "model response in the second round (optional)"]
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"history": "history"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
where the `prompt` and `response` columns should contain non-empty values, represent instruction and response respectively. The `query` column will be concatenated with the `prompt` column and used as input for the model.
|
||||
|
||||
The `history` column is a list consisting string tuples representing query-response pairs in history. Note that the responses **in each round will be used for training**.
|
||||
|
||||
For the pre-training datasets, only the `prompt` column will be used for training.
|
||||
|
||||
For the preference datasets, the `response` column should be a string list whose length is 2, with the preferred answers appearing first, for example:
|
||||
|
||||
```json
|
||||
{
|
||||
"instruction": "user instruction",
|
||||
"input": "user input",
|
||||
"output": [
|
||||
"chosen answer",
|
||||
"rejected answer"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
The dataset in sharegpt format should follow the below format:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "user instruction"
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "model response"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Regarding the above dataset, the `columns` in `dataset_info.json` should be:
|
||||
|
||||
```json
|
||||
"dataset_name": {
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
"role": "from",
|
||||
"content": "value"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
where the `messages` column should be a list whose length is even, and follow the `u/a/u/a/u/a` order.
|
||||
|
||||
Pre-training datasets and preference datasets are incompatible with the sharegpt format yet.
|
||||
|
||||
@@ -1,18 +1,107 @@
|
||||
如果您使用自定义数据集,请务必在 `dataset_info.json` 文件中以如下格式提供您的数据集定义。
|
||||
如果您使用自定义数据集,请务必在 `dataset_info.json` 文件中按照以下格式提供数据集定义。
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"hf_hub_url": "HuggingFace上的项目地址(若指定,则忽略下列三个参数)",
|
||||
"hf_hub_url": "Hugging Face 上的项目地址(若指定,则忽略下列三个参数)",
|
||||
"script_url": "包含数据加载脚本的本地文件夹名称(若指定,则忽略下列两个参数)",
|
||||
"file_name": "该目录下数据集文件的名称(若上述参数未指定,则此项必需)",
|
||||
"file_sha1": "数据集文件的SHA-1哈希值(可选)",
|
||||
"file_sha1": "数据集文件的SHA-1哈希值(可选,留空不影响训练)",
|
||||
"subset": "数据集子集的名称(可选,默认:None)",
|
||||
"ranking": "是否为偏好数据集(可选,默认:False)",
|
||||
"formatting": "数据集格式(可选,默认:alpaca,可以为 alpaca 或 sharegpt)",
|
||||
"columns": {
|
||||
"prompt": "数据集代表提示词的表头名称(默认:instruction)",
|
||||
"query": "数据集代表请求的表头名称(默认:input)",
|
||||
"response": "数据集代表回答的表头名称(默认:output)",
|
||||
"history": "数据集代表历史对话的表头名称(默认:None)"
|
||||
"prompt": "数据集代表提示词的表头名称(默认:instruction,用于 alpaca 格式)",
|
||||
"query": "数据集代表请求的表头名称(默认:input,用于 alpaca 格式)",
|
||||
"response": "数据集代表回答的表头名称(默认:output,用于 alpaca 格式)",
|
||||
"history": "数据集代表历史对话的表头名称(默认:None,用于 alpaca 格式)",
|
||||
"messages": "数据集代表消息列表的表头名称(默认:conversations,用于 sharegpt 格式)",
|
||||
"role": "消息中代表发送者身份的键名(默认:from,用于 sharegpt 格式)",
|
||||
"content": "消息中代表文本内容的键名(默认:value,用于 sharegpt 格式)"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
其中 `prompt` 和 `response` 列应当是非空的字符串。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。`history` 列应当是一个列表,其中每个元素是一个字符串二元组,分别代表用户请求和模型答复。
|
||||
添加后可通过指定 `--dataset 数据集名称` 参数使用自定义数据集。
|
||||
|
||||
该项目目前支持两种格式的数据集:**alpaca** 和 **sharegpt**,其中 alpaca 格式的数据集按照以下方式组织:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"instruction": "用户指令(必填)",
|
||||
"input": "用户输入(选填)",
|
||||
"output": "模型回答(必填)",
|
||||
"history": [
|
||||
["第一轮指令(选填)", "第一轮回答(选填)"],
|
||||
["第二轮指令(选填)", "第二轮回答(选填)"]
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的 `columns` 应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"columns": {
|
||||
"prompt": "instruction",
|
||||
"query": "input",
|
||||
"response": "output",
|
||||
"history": "history"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
其中 `prompt` 和 `response` 列应当是非空的字符串,分别代表用户指令和模型回答。`query` 列的内容将会和 `prompt` 列拼接作为模型输入。
|
||||
|
||||
`history` 列是由多个字符串二元组构成的列表,分别代表历史消息中每轮的指令和回答。注意每轮的模型回答**均会被用于训练**。
|
||||
|
||||
对于预训练数据集,仅 `prompt` 列中的内容会用于模型训练。
|
||||
|
||||
对于偏好数据集,`response` 列应当是一个长度为 2 的字符串列表,排在前面的代表更优的回答,例如:
|
||||
|
||||
```json
|
||||
{
|
||||
"instruction": "用户指令",
|
||||
"input": "用户输入",
|
||||
"output": [
|
||||
"优质回答",
|
||||
"劣质回答"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
而 sharegpt 格式的数据集按照以下方式组织:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"conversations": [
|
||||
{
|
||||
"from": "human",
|
||||
"value": "用户指令"
|
||||
},
|
||||
{
|
||||
"from": "gpt",
|
||||
"value": "模型回答"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
对于上述格式的数据,`dataset_info.json` 中的 `columns` 应为:
|
||||
|
||||
```json
|
||||
"数据集名称": {
|
||||
"columns": {
|
||||
"messages": "conversations",
|
||||
"role": "from",
|
||||
"content": "value"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
其中 `messages` 列必须为偶数长度的列表,且符合 `用户/模型/用户/模型/用户/模型` 的顺序。
|
||||
|
||||
预训练数据集和偏好数据集尚不支持 sharegpt 格式。
|
||||
|
||||
1
data/alpaca_data_en_52k.json.REMOVED.git-id
Normal file
1
data/alpaca_data_en_52k.json.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
||||
3779ddbc040543ab1834ef216c983d6fcc06cc9a
|
||||
1
data/alpaca_data_zh_51k.json.REMOVED.git-id
Normal file
1
data/alpaca_data_zh_51k.json.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
||||
fc9a6a3458caca2af8dafc6181773fe10c6d8657
|
||||
1
data/alpaca_gpt4_data_en.json.REMOVED.git-id
Normal file
1
data/alpaca_gpt4_data_en.json.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
||||
25508714b7879a1e5a6764ba7f979a980f549f1a
|
||||
1
data/alpaca_gpt4_data_zh.json.REMOVED.git-id
Normal file
1
data/alpaca_gpt4_data_zh.json.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
||||
7cb6a7d11455bddc3d495750a2392683d775b184
|
||||
@@ -1,6 +1,5 @@
|
||||
import json
|
||||
import datasets
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
_DESCRIPTION = "BELLE multiturn chat dataset."
|
||||
@@ -23,7 +22,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
|
||||
VERSION = datasets.Version("0.0.0")
|
||||
|
||||
def _info(self) -> datasets.DatasetInfo:
|
||||
def _info(self):
|
||||
features = datasets.Features({
|
||||
"instruction": datasets.Value("string"),
|
||||
"output": datasets.Value("string"),
|
||||
@@ -37,7 +36,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
citation=_CITATION
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager):
|
||||
file_path = dl_manager.download(_URL)
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
@@ -48,7 +47,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
)
|
||||
]
|
||||
|
||||
def _generate_examples(self, filepath: str) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat with history
|
||||
def _generate_examples(self, filepath: str):
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
for key, row in enumerate(f):
|
||||
data = json.loads(row)
|
||||
|
||||
1
data/comparison_gpt4_data_en.json.REMOVED.git-id
Normal file
1
data/comparison_gpt4_data_en.json.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
||||
f5cb08305ff5dc9c17a09809c54c8c8834aadc70
|
||||
1
data/comparison_gpt4_data_zh.json.REMOVED.git-id
Normal file
1
data/comparison_gpt4_data_zh.json.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
||||
aee47b7b443496e37808d7f34ef10403ff99bcc3
|
||||
@@ -3,7 +3,7 @@ import datasets
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
_DESCRIPTION = "An example of dataset for LLaMA."
|
||||
_DESCRIPTION = "An example of dataset."
|
||||
_CITATION = ""
|
||||
_HOMEPAGE = ""
|
||||
_LICENSE = ""
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import json
|
||||
import datasets
|
||||
from typing import Any, Dict, List
|
||||
from typing import List
|
||||
|
||||
|
||||
_DESCRIPTION = "Human preference data about helpfulness and harmlessness for ChatGLM."
|
||||
_DESCRIPTION = "Human preference data about helpfulness and harmlessness."
|
||||
_CITATION = ""
|
||||
_HOMEPAGE = "https://huggingface.co/datasets/Anthropic/hh-rlhf"
|
||||
_LICENSE = "mit"
|
||||
@@ -42,7 +42,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
|
||||
citation=_CITATION
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager):
|
||||
file_path = dl_manager.download_and_extract(_URLS)
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
@@ -59,7 +59,7 @@ class HhRlhfEn(datasets.GeneratorBasedBuilder):
|
||||
)
|
||||
]
|
||||
|
||||
def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat for ChatGLM
|
||||
def _generate_examples(self, filepaths: List[str]):
|
||||
key = 0
|
||||
for filepath in filepaths:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
|
||||
1
data/oaast_rm.json.REMOVED.git-id
Normal file
1
data/oaast_rm.json.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
||||
274079ea921762be356de85b18f13fa60b7ba8cb
|
||||
1
data/oaast_sft.json.REMOVED.git-id
Normal file
1
data/oaast_sft.json.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
||||
57fd080be5bffe4153fe3ee26a175e3d56da30f3
|
||||
1
data/sharegpt_zh_27k.json.REMOVED.git-id
Normal file
1
data/sharegpt_zh_27k.json.REMOVED.git-id
Normal file
@@ -0,0 +1 @@
|
||||
38c89869c6aeca2a3af9ea1e09afe460f9b46810
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import datasets
|
||||
from typing import Any, Dict, List
|
||||
from typing import List
|
||||
|
||||
|
||||
_DESCRIPTION = "UltraChat: Large-scale, Informative, and Diverse Multi-round Dialogue Data."
|
||||
@@ -21,15 +21,13 @@ _LICENSE = "cc-by-nc-4.0"
|
||||
_BASE_DATA_URL = "https://huggingface.co/datasets/stingning/ultrachat/resolve/main/train_{idx}.jsonl"
|
||||
|
||||
|
||||
class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
class UltraChat(datasets.GeneratorBasedBuilder):
|
||||
|
||||
VERSION = datasets.Version("0.0.0")
|
||||
|
||||
def _info(self) -> datasets.DatasetInfo:
|
||||
def _info(self):
|
||||
features = datasets.Features({
|
||||
"instruction": datasets.Value("string"),
|
||||
"output": datasets.Value("string"),
|
||||
"history": datasets.Sequence(datasets.Sequence(datasets.Value("string")))
|
||||
"conversations": [{"from": datasets.Value("string"), "value": datasets.Value("string")}]
|
||||
})
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
@@ -39,8 +37,8 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
citation=_CITATION
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager) -> List[datasets.SplitGenerator]:
|
||||
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(9)] # multiple shards
|
||||
def _split_generators(self, dl_manager: datasets.DownloadManager):
|
||||
file_paths = [dl_manager.download(_BASE_DATA_URL.format(idx=idx)) for idx in range(10)] # multiple shards
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
@@ -50,7 +48,7 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
)
|
||||
]
|
||||
|
||||
def _generate_examples(self, filepaths: List[str]) -> Dict[int, Dict[str, Any]]: # generate multi-turn chat for ChatGLM
|
||||
def _generate_examples(self, filepaths: List[str]):
|
||||
for filepath in filepaths:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
for row in f:
|
||||
@@ -58,19 +56,16 @@ class BelleMultiturn(datasets.GeneratorBasedBuilder):
|
||||
data = json.loads(row)
|
||||
except:
|
||||
continue
|
||||
key = data["id"]
|
||||
content = data["data"]
|
||||
key: int = data["id"]
|
||||
content: List[str] = data["data"]
|
||||
if len(content) % 2 == 1:
|
||||
content.pop(-1)
|
||||
if len(content) < 2:
|
||||
continue
|
||||
|
||||
query = content[-2]
|
||||
response = content[-1]
|
||||
history = [[content[2*i], content[2*i+1]] for i in range(len(content) // 2 - 1)]
|
||||
|
||||
conversations = [{
|
||||
"from": "human" if i % 2 == 0 else "gpt",
|
||||
"value": content[i]
|
||||
} for i in range(len(content))]
|
||||
yield key, {
|
||||
"instruction": query,
|
||||
"output": response,
|
||||
"history": history
|
||||
"conversations": conversations
|
||||
}
|
||||
|
||||
166
evaluation/ceval/ceval.py
Normal file
166
evaluation/ceval/ceval.py
Normal file
@@ -0,0 +1,166 @@
|
||||
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
|
||||
|
||||
_CITATION = """\
|
||||
@article{huang2023ceval,
|
||||
title={C-Eval: A Multi-Level Multi-Discipline Chinese Evaluation Suite for Foundation Models},
|
||||
author={Huang, Yuzhen and Bai, Yuzhuo and Zhu, Zhihao and Zhang, Junlei and Zhang, Jinghan and Su, Tangjun and Liu, Junteng and Lv, Chuancheng and Zhang, Yikai and Lei, Jiayi and Fu, Yao and Sun, Maosong and He, Junxian},
|
||||
journal={arXiv preprint arXiv:2305.08322},
|
||||
year={2023}
|
||||
}
|
||||
"""
|
||||
|
||||
_DESCRIPTION = """\
|
||||
C-Eval is a comprehensive Chinese evaluation suite for foundation models. It consists of 13948 multi-choice questions spanning 52 diverse disciplines and four difficulty levels.
|
||||
"""
|
||||
|
||||
_HOMEPAGE = "https://cevalbenchmark.com"
|
||||
|
||||
_LICENSE = "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License"
|
||||
|
||||
_URL = "ceval.zip"
|
||||
|
||||
task_list = [
|
||||
"computer_network",
|
||||
"operating_system",
|
||||
"computer_architecture",
|
||||
"college_programming",
|
||||
"college_physics",
|
||||
"college_chemistry",
|
||||
"advanced_mathematics",
|
||||
"probability_and_statistics",
|
||||
"discrete_mathematics",
|
||||
"electrical_engineer",
|
||||
"metrology_engineer",
|
||||
"high_school_mathematics",
|
||||
"high_school_physics",
|
||||
"high_school_chemistry",
|
||||
"high_school_biology",
|
||||
"middle_school_mathematics",
|
||||
"middle_school_biology",
|
||||
"middle_school_physics",
|
||||
"middle_school_chemistry",
|
||||
"veterinary_medicine",
|
||||
"college_economics",
|
||||
"business_administration",
|
||||
"marxism",
|
||||
"mao_zedong_thought",
|
||||
"education_science",
|
||||
"teacher_qualification",
|
||||
"high_school_politics",
|
||||
"high_school_geography",
|
||||
"middle_school_politics",
|
||||
"middle_school_geography",
|
||||
"modern_chinese_history",
|
||||
"ideological_and_moral_cultivation",
|
||||
"logic",
|
||||
"law",
|
||||
"chinese_language_and_literature",
|
||||
"art_studies",
|
||||
"professional_tour_guide",
|
||||
"legal_professional",
|
||||
"high_school_chinese",
|
||||
"high_school_history",
|
||||
"middle_school_history",
|
||||
"civil_servant",
|
||||
"sports_science",
|
||||
"plant_protection",
|
||||
"basic_medicine",
|
||||
"clinical_medicine",
|
||||
"urban_and_rural_planner",
|
||||
"accountant",
|
||||
"fire_engineer",
|
||||
"environmental_impact_assessment_engineer",
|
||||
"tax_accountant",
|
||||
"physician",
|
||||
]
|
||||
|
||||
|
||||
class CevalConfig(datasets.BuilderConfig):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(version=datasets.Version("1.0.0"), **kwargs)
|
||||
|
||||
|
||||
class Ceval(datasets.GeneratorBasedBuilder):
|
||||
BUILDER_CONFIGS = [
|
||||
CevalConfig(
|
||||
name=task_name,
|
||||
)
|
||||
for task_name in task_list
|
||||
]
|
||||
|
||||
def _info(self):
|
||||
features = datasets.Features(
|
||||
{
|
||||
"id": datasets.Value("int32"),
|
||||
"question": datasets.Value("string"),
|
||||
"A": datasets.Value("string"),
|
||||
"B": datasets.Value("string"),
|
||||
"C": datasets.Value("string"),
|
||||
"D": datasets.Value("string"),
|
||||
"answer": datasets.Value("string"),
|
||||
"explanation": datasets.Value("string"),
|
||||
}
|
||||
)
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
features=features,
|
||||
homepage=_HOMEPAGE,
|
||||
license=_LICENSE,
|
||||
citation=_CITATION,
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager):
|
||||
data_dir = dl_manager.download_and_extract(_URL)
|
||||
task_name = self.config.name
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TEST,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "test", f"{task_name}_test.csv"
|
||||
),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.VALIDATION,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "val", f"{task_name}_val.csv"
|
||||
),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "dev", f"{task_name}_dev.csv"
|
||||
),
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
def _generate_examples(self, filepath):
|
||||
df = pd.read_csv(filepath, encoding="utf-8")
|
||||
for i, instance in enumerate(df.to_dict(orient="records")):
|
||||
if "answer" not in instance.keys():
|
||||
instance["answer"] = ""
|
||||
if "explanation" not in instance.keys():
|
||||
instance["explanation"] = ""
|
||||
yield i, instance
|
||||
167
evaluation/cmmlu/cmmlu.py
Normal file
167
evaluation/cmmlu/cmmlu.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
|
||||
|
||||
_CITATION = """\
|
||||
@article{li2023cmmlu,
|
||||
title={CMMLU: Measuring massive multitask language understanding in Chinese},
|
||||
author={Haonan Li and Yixuan Zhang and Fajri Koto and Yifei Yang and Hai Zhao and Yeyun Gong and Nan Duan and Timothy Baldwin},
|
||||
journal={arXiv preprint arXiv:2306.09212},
|
||||
year={2023}
|
||||
}
|
||||
"""
|
||||
|
||||
_DESCRIPTION = """\
|
||||
CMMLU is a comprehensive Chinese assessment suite specifically designed to evaluate the advanced knowledge and reasoning abilities of LLMs within the Chinese language and cultural context.
|
||||
"""
|
||||
|
||||
_HOMEPAGE = "https://github.com/haonan-li/CMMLU"
|
||||
|
||||
_LICENSE = "Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License"
|
||||
|
||||
_URL = "cmmlu.zip"
|
||||
|
||||
task_list = [
|
||||
'agronomy',
|
||||
'anatomy',
|
||||
'ancient_chinese',
|
||||
'arts',
|
||||
'astronomy',
|
||||
'business_ethics',
|
||||
'chinese_civil_service_exam',
|
||||
'chinese_driving_rule',
|
||||
'chinese_food_culture',
|
||||
'chinese_foreign_policy',
|
||||
'chinese_history',
|
||||
'chinese_literature',
|
||||
'chinese_teacher_qualification',
|
||||
'clinical_knowledge',
|
||||
'college_actuarial_science',
|
||||
'college_education',
|
||||
'college_engineering_hydrology',
|
||||
'college_law',
|
||||
'college_mathematics',
|
||||
'college_medical_statistics',
|
||||
'college_medicine',
|
||||
'computer_science',
|
||||
'computer_security',
|
||||
'conceptual_physics',
|
||||
'construction_project_management',
|
||||
'economics',
|
||||
'education',
|
||||
'electrical_engineering',
|
||||
'elementary_chinese',
|
||||
'elementary_commonsense',
|
||||
'elementary_information_and_technology',
|
||||
'elementary_mathematics',
|
||||
'ethnology',
|
||||
'food_science',
|
||||
'genetics',
|
||||
'global_facts',
|
||||
'high_school_biology',
|
||||
'high_school_chemistry',
|
||||
'high_school_geography',
|
||||
'high_school_mathematics',
|
||||
'high_school_physics',
|
||||
'high_school_politics',
|
||||
'human_sexuality',
|
||||
'international_law',
|
||||
'journalism',
|
||||
'jurisprudence',
|
||||
'legal_and_moral_basis',
|
||||
'logical',
|
||||
'machine_learning',
|
||||
'management',
|
||||
'marketing',
|
||||
'marxist_theory',
|
||||
'modern_chinese',
|
||||
'nutrition',
|
||||
'philosophy',
|
||||
'professional_accounting',
|
||||
'professional_law',
|
||||
'professional_medicine',
|
||||
'professional_psychology',
|
||||
'public_relations',
|
||||
'security_study',
|
||||
'sociology',
|
||||
'sports_science',
|
||||
'traditional_chinese_medicine',
|
||||
'virology',
|
||||
'world_history',
|
||||
'world_religions',
|
||||
]
|
||||
|
||||
|
||||
class CMMLUConfig(datasets.BuilderConfig):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(version=datasets.Version("1.0.1"), **kwargs)
|
||||
|
||||
|
||||
class CMMLU(datasets.GeneratorBasedBuilder):
|
||||
BUILDER_CONFIGS = [
|
||||
CMMLUConfig(
|
||||
name=task_name,
|
||||
)
|
||||
for task_name in task_list
|
||||
]
|
||||
|
||||
def _info(self):
|
||||
features = datasets.Features(
|
||||
{
|
||||
"question": datasets.Value("string"),
|
||||
"A": datasets.Value("string"),
|
||||
"B": datasets.Value("string"),
|
||||
"C": datasets.Value("string"),
|
||||
"D": datasets.Value("string"),
|
||||
"answer": datasets.Value("string"),
|
||||
}
|
||||
)
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
features=features,
|
||||
homepage=_HOMEPAGE,
|
||||
license=_LICENSE,
|
||||
citation=_CITATION,
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager):
|
||||
data_dir = dl_manager.download_and_extract(_URL)
|
||||
task_name = self.config.name
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TEST,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(data_dir, f"test/{task_name}.csv"),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(data_dir, f"dev/{task_name}.csv"),
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
def _generate_examples(self, filepath):
|
||||
df = pd.read_csv(filepath, header=0, index_col=0, encoding="utf-8")
|
||||
for i, instance in enumerate(df.to_dict(orient="records")):
|
||||
question = instance.pop("Question", "")
|
||||
answer = instance.pop("Answer", "")
|
||||
instance["question"] = question
|
||||
instance["answer"] = answer
|
||||
yield i, instance
|
||||
167
evaluation/mmlu/mmlu.py
Normal file
167
evaluation/mmlu/mmlu.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
|
||||
|
||||
_CITATION = """\
|
||||
@article{hendryckstest2021,
|
||||
title={Measuring Massive Multitask Language Understanding},
|
||||
author={Dan Hendrycks and Collin Burns and Steven Basart and Andy Zou and Mantas Mazeika and Dawn Song and Jacob Steinhardt},
|
||||
journal={Proceedings of the International Conference on Learning Representations (ICLR)},
|
||||
year={2021}
|
||||
}
|
||||
"""
|
||||
|
||||
_DESCRIPTION = """\
|
||||
Measuring Massive Multitask Language Understanding by Dan Hendrycks, Collin Burns, Steven Basart, Andy Zou, Mantas Mazeika, Dawn Song, and Jacob Steinhardt (ICLR 2021).
|
||||
"""
|
||||
|
||||
_HOMEPAGE = "https://github.com/hendrycks/test"
|
||||
|
||||
_LICENSE = "MIT"
|
||||
|
||||
_URL = "mmlu.zip"
|
||||
|
||||
task_list = [
|
||||
"high_school_european_history",
|
||||
"business_ethics",
|
||||
"clinical_knowledge",
|
||||
"medical_genetics",
|
||||
"high_school_us_history",
|
||||
"high_school_physics",
|
||||
"high_school_world_history",
|
||||
"virology",
|
||||
"high_school_microeconomics",
|
||||
"econometrics",
|
||||
"college_computer_science",
|
||||
"high_school_biology",
|
||||
"abstract_algebra",
|
||||
"professional_accounting",
|
||||
"philosophy",
|
||||
"professional_medicine",
|
||||
"nutrition",
|
||||
"global_facts",
|
||||
"machine_learning",
|
||||
"security_studies",
|
||||
"public_relations",
|
||||
"professional_psychology",
|
||||
"prehistory",
|
||||
"anatomy",
|
||||
"human_sexuality",
|
||||
"college_medicine",
|
||||
"high_school_government_and_politics",
|
||||
"college_chemistry",
|
||||
"logical_fallacies",
|
||||
"high_school_geography",
|
||||
"elementary_mathematics",
|
||||
"human_aging",
|
||||
"college_mathematics",
|
||||
"high_school_psychology",
|
||||
"formal_logic",
|
||||
"high_school_statistics",
|
||||
"international_law",
|
||||
"high_school_mathematics",
|
||||
"high_school_computer_science",
|
||||
"conceptual_physics",
|
||||
"miscellaneous",
|
||||
"high_school_chemistry",
|
||||
"marketing",
|
||||
"professional_law",
|
||||
"management",
|
||||
"college_physics",
|
||||
"jurisprudence",
|
||||
"world_religions",
|
||||
"sociology",
|
||||
"us_foreign_policy",
|
||||
"high_school_macroeconomics",
|
||||
"computer_security",
|
||||
"moral_scenarios",
|
||||
"moral_disputes",
|
||||
"electrical_engineering",
|
||||
"astronomy",
|
||||
"college_biology",
|
||||
]
|
||||
|
||||
|
||||
class MMLUConfig(datasets.BuilderConfig):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(version=datasets.Version("1.0.0"), **kwargs)
|
||||
|
||||
|
||||
class MMLU(datasets.GeneratorBasedBuilder):
|
||||
BUILDER_CONFIGS = [
|
||||
MMLUConfig(
|
||||
name=task_name,
|
||||
)
|
||||
for task_name in task_list
|
||||
]
|
||||
|
||||
def _info(self):
|
||||
features = datasets.Features(
|
||||
{
|
||||
"question": datasets.Value("string"),
|
||||
"A": datasets.Value("string"),
|
||||
"B": datasets.Value("string"),
|
||||
"C": datasets.Value("string"),
|
||||
"D": datasets.Value("string"),
|
||||
"answer": datasets.Value("string"),
|
||||
}
|
||||
)
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
features=features,
|
||||
homepage=_HOMEPAGE,
|
||||
license=_LICENSE,
|
||||
citation=_CITATION,
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager):
|
||||
data_dir = dl_manager.download_and_extract(_URL)
|
||||
task_name = self.config.name
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TEST,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "data", "test", f"{task_name}_test.csv"
|
||||
),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.VALIDATION,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "data", "val", f"{task_name}_val.csv"
|
||||
),
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"filepath": os.path.join(
|
||||
data_dir, "data", "dev", f"{task_name}_dev.csv"
|
||||
),
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
def _generate_examples(self, filepath):
|
||||
df = pd.read_csv(filepath)
|
||||
df.columns = ["question", "A", "B", "C", "D", "answer"]
|
||||
|
||||
for i, instance in enumerate(df.to_dict(orient="records")):
|
||||
yield i, instance
|
||||
@@ -1,16 +1,20 @@
|
||||
torch>=1.13.1
|
||||
transformers>=4.29.1
|
||||
datasets>=2.12.0
|
||||
transformers>=4.31.0,<4.35.0
|
||||
datasets>=2.14.0
|
||||
accelerate>=0.21.0
|
||||
peft>=0.4.0
|
||||
trl>=0.4.7
|
||||
peft>=0.6.0
|
||||
trl==0.7.2
|
||||
gradio>=3.38.0,<4.0.0
|
||||
scipy
|
||||
sentencepiece
|
||||
protobuf
|
||||
tiktoken
|
||||
fire
|
||||
jieba
|
||||
rouge-chinese
|
||||
nltk
|
||||
gradio>=3.36.0
|
||||
uvicorn
|
||||
pydantic==1.10.11
|
||||
fastapi==0.95.1
|
||||
pydantic
|
||||
fastapi
|
||||
sse-starlette
|
||||
matplotlib
|
||||
|
||||
4
setup.py
4
setup.py
@@ -25,12 +25,12 @@ def main():
|
||||
version=get_version(),
|
||||
author="hiyouga",
|
||||
author_email="hiyouga" "@" "buaa.edu.cn",
|
||||
description="Easy-to-use fine-tuning framework using PEFT",
|
||||
description="Easy-to-use LLM fine-tuning framework",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
keywords=["LLaMA", "BLOOM", "Falcon", "LLM", "ChatGPT", "transformer", "pytorch", "deep learning"],
|
||||
license="Apache 2.0 License",
|
||||
url="https://github.com/hiyouga/LLaMA-Efficient-Tuning",
|
||||
url="https://github.com/hiyouga/LLaMA-Factory",
|
||||
package_dir={"": "src"},
|
||||
packages=find_packages("src"),
|
||||
python_requires=">=3.8.0",
|
||||
|
||||
@@ -1,18 +1,12 @@
|
||||
# coding=utf-8
|
||||
# Implements API for fine-tuned models in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
|
||||
# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
|
||||
# Visit http://localhost:8000/docs for document.
|
||||
|
||||
import uvicorn
|
||||
|
||||
from llmtuner import ChatModel
|
||||
from llmtuner.api.app import create_app
|
||||
from llmtuner.tuner import get_infer_args
|
||||
from llmtuner import ChatModel, create_app
|
||||
|
||||
|
||||
def main():
|
||||
chat_model = ChatModel(*get_infer_args())
|
||||
chat_model = ChatModel()
|
||||
app = create_app(chat_model)
|
||||
print("Visit http://localhost:8000/docs for API document.")
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
||||
|
||||
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
# coding=utf-8
|
||||
# Implements stream chat in command line for fine-tuned models.
|
||||
# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
|
||||
|
||||
import readline
|
||||
from llmtuner import ChatModel
|
||||
from llmtuner.tuner import get_infer_args
|
||||
|
||||
|
||||
def main():
|
||||
chat_model = ChatModel(*get_infer_args())
|
||||
chat_model = ChatModel()
|
||||
history = []
|
||||
print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
|
||||
|
||||
|
||||
190
src/evaluate.py
Normal file
190
src/evaluate.py
Normal file
@@ -0,0 +1,190 @@
|
||||
# coding=utf-8
|
||||
# Evaluates the performance of pre-trained models.
|
||||
# Usage: python evaluate.py --model_name_or_path path_to_model --checkpoint_dir path_to_ckpt --template vanilla
|
||||
# --task ceval --split validation --lang zh --n_shot 5 --batch_size 4 --save_name result
|
||||
# Inspired by: https://github.com/hendrycks/test/blob/master/evaluate_flan.py
|
||||
|
||||
import os
|
||||
import fire
|
||||
import json
|
||||
import torch
|
||||
import numpy as np
|
||||
import transformers
|
||||
from collections import Counter
|
||||
from datasets import load_dataset
|
||||
from dataclasses import dataclass
|
||||
from tqdm import tqdm, trange
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
from llmtuner import ChatModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset
|
||||
|
||||
|
||||
choices = ["A", "B", "C", "D"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalTemplate:
|
||||
|
||||
system: str
|
||||
choice: str
|
||||
answer: str
|
||||
prefix: str
|
||||
|
||||
def parse_example(
|
||||
self,
|
||||
example: Dict[str, str]
|
||||
) -> Tuple[str, str]:
|
||||
candidates = [self.choice.format(choice=ch, content=example[ch]) for ch in choices if ch in example]
|
||||
return "".join([example["question"]] + candidates + [self.answer]), example["answer"]
|
||||
|
||||
def format_example(
|
||||
self,
|
||||
target_data: Dict[str, str],
|
||||
support_set: "Dataset",
|
||||
subject_name: str,
|
||||
use_history: bool
|
||||
) -> Tuple[str, str, List[Tuple[str, str]]]:
|
||||
query, resp = self.parse_example(target_data)
|
||||
history = [self.parse_example(support_set[k]) for k in range(len(support_set))]
|
||||
|
||||
if len(history):
|
||||
temp = history.pop(0)
|
||||
history.insert(0, (self.system.format(subject=subject_name) + temp[0], temp[1]))
|
||||
else:
|
||||
query = self.system.format(subject=subject_name) + query
|
||||
|
||||
if not use_history:
|
||||
query = "\n\n".join(["".join(item) for item in history] + [query])
|
||||
history = []
|
||||
return query.strip(), resp, history
|
||||
|
||||
|
||||
eval_templates = {
|
||||
"en": EvalTemplate(
|
||||
system="The following are multiple choice questions (with answers) about {subject}.\n\n",
|
||||
choice="\n{choice}. {content}",
|
||||
answer="\nAnswer: ",
|
||||
prefix=" "
|
||||
),
|
||||
"zh": EvalTemplate(
|
||||
system="以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n",
|
||||
choice="\n{choice}. {content}",
|
||||
answer="\n答案:",
|
||||
prefix="\n"
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def batch_inference(
|
||||
chat_model: ChatModel,
|
||||
batch_input: Dict[str, torch.Tensor],
|
||||
prefix_char: str
|
||||
) -> List[str]:
|
||||
logits = chat_model.model(**batch_input).logits
|
||||
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
|
||||
nextword_logits = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
|
||||
probs = torch.nn.functional.softmax(
|
||||
torch.stack(
|
||||
[
|
||||
nextword_logits[:, chat_model.tokenizer.encode(prefix_char + choice, add_special_tokens=False)[-1]]
|
||||
for choice in choices
|
||||
],
|
||||
dim=-1
|
||||
),
|
||||
dim=-1
|
||||
).detach()
|
||||
return [chr(ord("A") + offset.item()) for offset in torch.argmax(probs, dim=-1)]
|
||||
|
||||
|
||||
def evaluate(
|
||||
model_name_or_path: str,
|
||||
finetuning_type: Optional[str] = "lora",
|
||||
checkpoint_dir: Optional[str] = None,
|
||||
template: Optional[str] = "vanilla",
|
||||
task: Optional[str] = "ceval",
|
||||
dataset_dir: Optional[str] = "evaluation",
|
||||
split: Optional[Literal["validation", "test"]] = "validation",
|
||||
lang: Optional[Literal["zh", "en"]] = "zh",
|
||||
n_shot: Optional[int] = 5,
|
||||
n_avg: Optional[int] = 1,
|
||||
batch_size: Optional[int] = 4,
|
||||
save_name: Optional[str] = None,
|
||||
seed: Optional[int] = 42
|
||||
):
|
||||
with open(os.path.join(dataset_dir, task, "mapping.json"), "r", encoding="utf-8") as f:
|
||||
categorys: Dict[str, Dict[str, str]] = json.load(f)
|
||||
|
||||
transformers.set_seed(seed)
|
||||
chat_model = ChatModel(dict(
|
||||
model_name_or_path=model_name_or_path,
|
||||
finetuning_type=finetuning_type,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
template=template
|
||||
))
|
||||
chat_model.tokenizer.padding_side = "right" # avoid overflow issue in batched inference for llama2
|
||||
eval_template = eval_templates[lang]
|
||||
|
||||
category_corrects: Dict[str, np.ndarray] = {
|
||||
subj: np.array([], dtype="bool") for subj in ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
|
||||
}
|
||||
pbar = tqdm(categorys.keys(), desc="Processing subjects", position=0)
|
||||
results = {}
|
||||
for subject in pbar:
|
||||
dataset = load_dataset(os.path.join(dataset_dir, task), subject)
|
||||
labels, answers, all_outputs = [], [], []
|
||||
for epoch in range(n_avg):
|
||||
pbar.set_postfix_str("{} Trial: {}".format(categorys[subject]["name"], epoch))
|
||||
inputs, outputs = [], []
|
||||
for i in trange(len(dataset[split]), desc="Formatting batches", position=1, leave=False):
|
||||
support_set = dataset["train"].shuffle().select(range(min(n_shot, len(dataset["train"]))))
|
||||
query, resp, history = eval_template.format_example(
|
||||
target_data=dataset[split][i],
|
||||
support_set=support_set,
|
||||
subject_name=categorys[subject]["name"],
|
||||
use_history=chat_model.template.use_history
|
||||
)
|
||||
input_ids, _ = chat_model.template.encode_oneturn(
|
||||
tokenizer=chat_model.tokenizer, query=query, resp=resp, history=history
|
||||
)
|
||||
inputs.append({"input_ids": input_ids, "attention_mask": [1] * len(input_ids)})
|
||||
if epoch == 0:
|
||||
labels.append(resp)
|
||||
|
||||
for i in trange(0, len(inputs), batch_size, desc="Predicting batches", position=1, leave=False):
|
||||
batch_input = chat_model.tokenizer.pad(
|
||||
inputs[i : i + batch_size], return_attention_mask=True, return_tensors="pt"
|
||||
).to(chat_model.model.device)
|
||||
preds = batch_inference(chat_model, batch_input, eval_template.prefix)
|
||||
outputs += preds
|
||||
all_outputs.append(outputs)
|
||||
|
||||
for i in range(len(all_outputs[0])):
|
||||
count = Counter([all_outputs[epoch][i] for epoch in range(n_avg)])
|
||||
answers.append(count.most_common(1)[0][0])
|
||||
|
||||
corrects = (np.array(answers) == np.array(labels))
|
||||
category_name = categorys[subject]["category"]
|
||||
category_corrects[category_name] = np.concatenate([category_corrects[category_name], corrects], axis=0)
|
||||
category_corrects["Average"] = np.concatenate([category_corrects["Average"], corrects], axis=0)
|
||||
results[subject] = {str(i): answers[i] for i in range(len(answers))}
|
||||
|
||||
score_info = "\n".join([
|
||||
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
|
||||
for category_name, category_correct in category_corrects.items() if len(category_correct)
|
||||
])
|
||||
|
||||
print(score_info)
|
||||
if save_name is not None:
|
||||
with open(save_name + ".json", "w", encoding="utf-8", newline="\n") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
|
||||
with open(save_name + ".log", "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(score_info)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(evaluate)
|
||||
@@ -1,16 +1,8 @@
|
||||
# coding=utf-8
|
||||
# Exports the fine-tuned model.
|
||||
# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
|
||||
|
||||
from llmtuner.tuner import get_train_args, load_model_and_tokenizer
|
||||
from llmtuner import export_model
|
||||
|
||||
|
||||
def main():
|
||||
model_args, _, training_args, finetuning_args, _ = get_train_args()
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
print("model and tokenizer have been saved at:", training_args.output_dir)
|
||||
export_model()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
# Level: api, webui > chat > tuner > dsets > extras, hparams
|
||||
|
||||
from llmtuner.api import create_app
|
||||
from llmtuner.chat import ChatModel
|
||||
from llmtuner.tuner import export_model, run_exp
|
||||
from llmtuner.webui import create_ui, create_web_demo
|
||||
|
||||
|
||||
__version__ = "0.1.4"
|
||||
__version__ = "0.2.1"
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
from llmtuner.api.app import create_app
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import json
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi import FastAPI, HTTPException, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from contextlib import asynccontextmanager
|
||||
from sse_starlette import EventSourceResponse
|
||||
from typing import List, Tuple
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llmtuner.tuner import get_infer_args
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.chat.stream_chat import ChatModel
|
||||
from llmtuner.chat import ChatModel
|
||||
from llmtuner.api.protocol import (
|
||||
Role,
|
||||
Finish,
|
||||
@@ -30,6 +31,13 @@ async def lifespan(app: FastAPI): # collects GPU memory
|
||||
torch_gc()
|
||||
|
||||
|
||||
def to_json(data: BaseModel) -> str:
|
||||
try: # pydantic v2
|
||||
return json.dumps(data.model_dump(exclude_unset=True), ensure_ascii=False)
|
||||
except: # pydantic v1
|
||||
return data.json(exclude_unset=True, ensure_ascii=False)
|
||||
|
||||
|
||||
def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
|
||||
@@ -46,30 +54,37 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
model_card = ModelCard(id="gpt-3.5-turbo")
|
||||
return ModelList(data=[model_card])
|
||||
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
||||
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK)
|
||||
async def create_chat_completion(request: ChatCompletionRequest):
|
||||
if request.messages[-1].role != Role.USER:
|
||||
raise HTTPException(status_code=400, detail="Invalid request")
|
||||
query = request.messages[-1].content
|
||||
if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request")
|
||||
|
||||
query = request.messages[-1].content
|
||||
prev_messages = request.messages[:-1]
|
||||
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
|
||||
prefix = prev_messages.pop(0).content
|
||||
system = prev_messages.pop(0).content
|
||||
else:
|
||||
prefix = None
|
||||
system = None
|
||||
|
||||
history = []
|
||||
if len(prev_messages) % 2 == 0:
|
||||
for i in range(0, len(prev_messages), 2):
|
||||
if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
|
||||
history.append([prev_messages[i].content, prev_messages[i+1].content])
|
||||
else:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...")
|
||||
|
||||
if request.stream:
|
||||
generate = predict(query, history, prefix, request)
|
||||
generate = predict(query, history, system, request)
|
||||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||||
|
||||
response, (prompt_length, response_length) = chat_model.chat(
|
||||
query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
|
||||
query, history, system,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens,
|
||||
num_return_sequences=request.n
|
||||
)
|
||||
|
||||
usage = ChatCompletionResponseUsage(
|
||||
@@ -78,25 +93,29 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
total_tokens=prompt_length+response_length
|
||||
)
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=0,
|
||||
message=ChatMessage(role=Role.ASSISTANT, content=response),
|
||||
choices = [ChatCompletionResponseChoice(
|
||||
index=i,
|
||||
message=ChatMessage(role=Role.ASSISTANT, content=choice),
|
||||
finish_reason=Finish.STOP
|
||||
)
|
||||
) for i, choice in enumerate(response)]
|
||||
|
||||
return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
|
||||
return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
|
||||
|
||||
async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
|
||||
async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
delta=DeltaMessage(role=Role.ASSISTANT),
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield to_json(chunk)
|
||||
|
||||
for new_text in chat_model.stream_chat(
|
||||
query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
|
||||
query, history, system,
|
||||
do_sample=request.do_sample,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
max_new_tokens=request.max_tokens
|
||||
):
|
||||
if len(new_text) == 0:
|
||||
continue
|
||||
@@ -107,7 +126,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
finish_reason=None
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield to_json(chunk)
|
||||
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=0,
|
||||
@@ -115,13 +134,13 @@ def create_app(chat_model: ChatModel) -> FastAPI:
|
||||
finish_reason=Finish.STOP
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
|
||||
yield chunk.json(exclude_unset=True, ensure_ascii=False)
|
||||
yield to_json(chunk)
|
||||
yield "[DONE]"
|
||||
|
||||
return app
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
chat_model = ChatModel(*get_infer_args())
|
||||
chat_model = ChatModel()
|
||||
app = create_app(chat_model)
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
|
||||
|
||||
@@ -20,9 +20,6 @@ class ModelCard(BaseModel):
|
||||
object: Optional[str] = "model"
|
||||
created: Optional[int] = Field(default_factory=lambda: int(time.time()))
|
||||
owned_by: Optional[str] = "owner"
|
||||
root: Optional[str] = None
|
||||
parent: Optional[str] = None
|
||||
permission: Optional[list] = []
|
||||
|
||||
|
||||
class ModelList(BaseModel):
|
||||
@@ -43,6 +40,7 @@ class DeltaMessage(BaseModel):
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[ChatMessage]
|
||||
do_sample: Optional[bool] = True
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = 1
|
||||
|
||||
@@ -1,79 +1,74 @@
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||
from threading import Thread
|
||||
from transformers import TextIteratorStreamer
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
|
||||
from llmtuner.extras.misc import get_logits_processor
|
||||
from llmtuner.extras.template import get_template
|
||||
from llmtuner.tuner import load_model_and_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
from llmtuner.extras.misc import dispatch_model, get_logits_processor
|
||||
from llmtuner.extras.template import get_template_and_fix_tokenizer
|
||||
from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer
|
||||
|
||||
|
||||
class ChatModel:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments"
|
||||
) -> None:
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
|
||||
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
from accelerate import dispatch_model
|
||||
from accelerate.utils import infer_auto_device_map, get_balanced_memory
|
||||
device_map = infer_auto_device_map(self.model, max_memory=get_balanced_memory(self.model))
|
||||
self.model = dispatch_model(self.model, device_map)
|
||||
else:
|
||||
self.model = self.model.cuda()
|
||||
|
||||
self.template = get_template(data_args.template)
|
||||
self.source_prefix = data_args.source_prefix
|
||||
self.generating_args = generating_args
|
||||
self.tokenizer.padding_side = "left"
|
||||
self.model = dispatch_model(self.model)
|
||||
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
||||
self.system_prompt = data_args.system_prompt
|
||||
|
||||
def process_args(
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
system: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Tuple[Dict[str, Any], int]:
|
||||
prefix = prefix or self.source_prefix
|
||||
|
||||
prompt = self.template.get_prompt(query, history, prefix, self.tokenizer.eos_token)
|
||||
inputs = self.tokenizer([prompt], return_tensors="pt")
|
||||
inputs = inputs.to(self.model.device)
|
||||
prompt_length = len(inputs["input_ids"][0])
|
||||
system = system or self.system_prompt
|
||||
prompt, _ = self.template.encode_oneturn(
|
||||
tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
|
||||
)
|
||||
prompt_length = len(prompt)
|
||||
input_ids = torch.tensor([prompt], device=self.model.device)
|
||||
|
||||
do_sample = input_kwargs.pop("do_sample", None)
|
||||
temperature = input_kwargs.pop("temperature", None)
|
||||
top_p = input_kwargs.pop("top_p", None)
|
||||
top_k = input_kwargs.pop("top_k", None)
|
||||
num_return_sequences = input_kwargs.pop("num_return_sequences", None)
|
||||
repetition_penalty = input_kwargs.pop("repetition_penalty", None)
|
||||
max_length = input_kwargs.pop("max_length", None)
|
||||
max_new_tokens = input_kwargs.pop("max_new_tokens", None)
|
||||
|
||||
gen_kwargs = self.generating_args.to_dict()
|
||||
gen_kwargs.update(dict(
|
||||
input_ids=inputs["input_ids"],
|
||||
do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
|
||||
temperature=temperature or gen_kwargs["temperature"],
|
||||
top_p=top_p or gen_kwargs["top_p"],
|
||||
top_k=top_k or gen_kwargs["top_k"],
|
||||
repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
|
||||
logits_processor=get_logits_processor()
|
||||
generating_args = self.generating_args.to_dict()
|
||||
generating_args.update(dict(
|
||||
do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
|
||||
temperature=temperature or generating_args["temperature"],
|
||||
top_p=top_p or generating_args["top_p"],
|
||||
top_k=top_k or generating_args["top_k"],
|
||||
num_return_sequences=num_return_sequences or 1,
|
||||
repetition_penalty=repetition_penalty or generating_args["repetition_penalty"],
|
||||
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||
pad_token_id=self.tokenizer.pad_token_id
|
||||
))
|
||||
|
||||
if isinstance(num_return_sequences, int) and num_return_sequences > 1:
|
||||
generating_args["do_sample"] = True
|
||||
|
||||
if max_length:
|
||||
gen_kwargs.pop("max_new_tokens", None)
|
||||
gen_kwargs["max_length"] = max_length
|
||||
generating_args.pop("max_new_tokens", None)
|
||||
generating_args["max_length"] = max_length
|
||||
|
||||
if max_new_tokens:
|
||||
gen_kwargs.pop("max_length", None)
|
||||
gen_kwargs["max_new_tokens"] = max_new_tokens
|
||||
generating_args.pop("max_length", None)
|
||||
generating_args["max_new_tokens"] = max_new_tokens
|
||||
|
||||
gen_kwargs = dict(
|
||||
inputs=input_ids,
|
||||
generation_config=GenerationConfig(**generating_args),
|
||||
logits_processor=get_logits_processor()
|
||||
)
|
||||
|
||||
return gen_kwargs, prompt_length
|
||||
|
||||
@@ -82,14 +77,18 @@ class ChatModel:
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
system: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Tuple[str, Tuple[int, int]]:
|
||||
gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
|
||||
generation_output = self.model.generate(**gen_kwargs)
|
||||
outputs = generation_output.tolist()[0][prompt_length:]
|
||||
response = self.tokenizer.decode(outputs, skip_special_tokens=True)
|
||||
response_length = len(outputs)
|
||||
) -> Tuple[List[str], Tuple[int, int]]:
|
||||
gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
|
||||
generate_output = self.model.generate(**gen_kwargs)
|
||||
response_ids = generate_output[:, prompt_length:]
|
||||
response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
|
||||
response_length = 0
|
||||
for i in range(len(response_ids)):
|
||||
eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero()
|
||||
response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i])
|
||||
|
||||
return response, (prompt_length, response_length)
|
||||
|
||||
@torch.inference_mode()
|
||||
@@ -97,10 +96,10 @@ class ChatModel:
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = None,
|
||||
system: Optional[str] = None,
|
||||
**input_kwargs
|
||||
) -> Generator[str, None, None]:
|
||||
gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
|
||||
gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
|
||||
streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
||||
gen_kwargs["streamer"] = streamer
|
||||
|
||||
|
||||
@@ -1,70 +1,48 @@
|
||||
import os
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from datasets import concatenate_datasets, interleave_datasets, load_dataset
|
||||
|
||||
from llmtuner.dsets.utils import checksum, EXT2TYPE
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset
|
||||
from datasets import Dataset, IterableDataset
|
||||
from llmtuner.hparams import ModelArguments, DataArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
EXT2TYPE = {
|
||||
"csv": "csv",
|
||||
"json": "json",
|
||||
"jsonl": "json",
|
||||
"txt": "text"
|
||||
}
|
||||
|
||||
|
||||
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
|
||||
if file_sha1 is None:
|
||||
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
|
||||
return
|
||||
|
||||
if len(data_files) != 1:
|
||||
logger.warning("Checksum failed: too many files.")
|
||||
return
|
||||
|
||||
with open(data_files[0], "rb") as f:
|
||||
sha1 = hashlib.sha1(f.read()).hexdigest()
|
||||
if sha1 != file_sha1:
|
||||
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
|
||||
|
||||
|
||||
def get_dataset(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments"
|
||||
) -> "Dataset":
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
max_samples = data_args.max_samples
|
||||
all_datasets: List["Dataset"] = [] # support multiple datasets
|
||||
all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets
|
||||
|
||||
for dataset_attr in data_args.dataset_list:
|
||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||
|
||||
if dataset_attr.load_from == "hf_hub":
|
||||
data_path = dataset_attr.dataset_name
|
||||
data_name = dataset_attr.subset
|
||||
data_files = None
|
||||
elif dataset_attr.load_from == "script":
|
||||
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
||||
data_name = dataset_attr.subset
|
||||
data_files = None
|
||||
elif dataset_attr.load_from == "file":
|
||||
data_path = None
|
||||
data_path, data_name = None, None
|
||||
data_files: List[str] = []
|
||||
|
||||
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # directory
|
||||
if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is directory
|
||||
for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
|
||||
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
|
||||
if data_path is None:
|
||||
data_path = EXT2TYPE.get(file_name.split(".")[-1], None)
|
||||
else:
|
||||
assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file type does not match."
|
||||
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # single file
|
||||
assert data_path == EXT2TYPE.get(file_name.split(".")[-1], None), "file types are not identical."
|
||||
elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # is file
|
||||
data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
|
||||
data_path = EXT2TYPE.get(dataset_attr.dataset_name.split(".")[-1], None)
|
||||
else:
|
||||
@@ -76,24 +54,75 @@ def get_dataset(
|
||||
raise NotImplementedError
|
||||
|
||||
dataset = load_dataset(
|
||||
data_path,
|
||||
path=data_path,
|
||||
name=data_name,
|
||||
data_files=data_files,
|
||||
split=data_args.split,
|
||||
cache_dir=model_args.cache_dir,
|
||||
streaming=data_args.streaming,
|
||||
use_auth_token=True if model_args.use_auth_token else None
|
||||
token=model_args.hf_hub_token,
|
||||
streaming=data_args.streaming
|
||||
)
|
||||
|
||||
if max_samples is not None:
|
||||
max_samples_temp = min(len(dataset), max_samples)
|
||||
dataset = dataset.select(range(max_samples_temp))
|
||||
if max_samples is not None: # truncate dataset
|
||||
dataset = dataset.select(range(min(len(dataset), max_samples)))
|
||||
|
||||
for column_name in ["prompt", "query", "response", "history"]: # align datasets
|
||||
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
|
||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
||||
def convert_format(examples: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
|
||||
# convert dataset from sharegpt format to alpaca format
|
||||
outputs = {"prompt": [], "query": [], "response": [], "history": []}
|
||||
for msg_list in examples[dataset_attr.messages]:
|
||||
msg_list = msg_list[:len(msg_list) // 2 * 2] # should be multiples of 2
|
||||
if len(msg_list) == 0:
|
||||
continue
|
||||
|
||||
if dataset_attr.source_prefix: # add prefix
|
||||
dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix})
|
||||
msg_pairs = []
|
||||
user_role, assistant_role = None, None
|
||||
for idx in range(0, len(msg_list), 2):
|
||||
if user_role is None and assistant_role is None:
|
||||
user_role = msg_list[idx][dataset_attr.role]
|
||||
assistant_role = msg_list[idx + 1][dataset_attr.role]
|
||||
else:
|
||||
if (
|
||||
msg_list[idx][dataset_attr.role] != user_role
|
||||
or msg_list[idx+1][dataset_attr.role] != assistant_role
|
||||
):
|
||||
raise ValueError("Only accepts conversation in u/a/u/a/u/a order.")
|
||||
msg_pairs.append((msg_list[idx][dataset_attr.content], msg_list[idx + 1][dataset_attr.content]))
|
||||
|
||||
if len(msg_pairs) != 0:
|
||||
outputs["prompt"].append(msg_pairs[-1][0])
|
||||
outputs["query"].append("")
|
||||
outputs["response"].append(msg_pairs[-1][1])
|
||||
outputs["history"].append(msg_pairs[:-1])
|
||||
|
||||
return outputs
|
||||
|
||||
if dataset_attr.formatting == "sharegpt": # convert format
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
kwargs = {}
|
||||
if not data_args.streaming:
|
||||
kwargs = dict(
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=(not data_args.overwrite_cache),
|
||||
desc="Converting format of dataset"
|
||||
)
|
||||
|
||||
dataset = dataset.map(
|
||||
convert_format,
|
||||
batched=True,
|
||||
remove_columns=column_names,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
for column_name in ["prompt", "query", "response", "history"]: # align dataset
|
||||
if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
|
||||
dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)
|
||||
|
||||
if dataset_attr.system_prompt: # add system prompt
|
||||
system_prompt = dataset_attr.system_prompt
|
||||
if data_args.streaming:
|
||||
dataset = dataset.map(lambda _: {"system": system_prompt})
|
||||
else:
|
||||
dataset = dataset.add_column("system", [system_prompt] * len(dataset))
|
||||
|
||||
all_datasets.append(dataset)
|
||||
|
||||
@@ -106,7 +135,11 @@ def get_dataset(
|
||||
elif data_args.mix_strategy.startswith("interleave"):
|
||||
if not data_args.streaming:
|
||||
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
||||
stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
|
||||
return interleave_datasets(all_datasets, stopping_strategy=stopping_strategy)
|
||||
return interleave_datasets(
|
||||
datasets=all_datasets,
|
||||
probabilities=data_args.interleave_probs,
|
||||
seed=data_args.seed,
|
||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown mixing strategy.")
|
||||
|
||||
@@ -1,74 +1,110 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal
|
||||
import os
|
||||
import tiktoken
|
||||
from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
|
||||
|
||||
from datasets import load_from_disk
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.template import get_template
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.template import get_template_and_fix_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
from llmtuner.hparams import DataArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def preprocess_dataset(
|
||||
dataset: "Dataset",
|
||||
dataset: Union["Dataset", "IterableDataset"],
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo"]
|
||||
) -> "Dataset":
|
||||
column_names = list(dataset.column_names or [])
|
||||
template = get_template(data_args.template)
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
template = get_template_and_fix_tokenizer(data_args.template, tokenizer)
|
||||
|
||||
if data_args.train_on_prompt and template.efficient_eos:
|
||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||
|
||||
def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
|
||||
for i in range(len(examples["prompt"])):
|
||||
query, response = examples["prompt"][i], examples["response"][i]
|
||||
query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
|
||||
history = history if "history" in examples and examples["history"][i] else []
|
||||
prefix = prefix if "prefix" in examples and examples["prefix"][i] else ""
|
||||
yield query, response, history, prefix
|
||||
history = examples["history"][i] if "history" in examples else None
|
||||
system = examples["system"][i] if "system" in examples else None
|
||||
yield query, response, history, system
|
||||
|
||||
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
||||
# build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
|
||||
tokenized_examples = tokenizer(examples["prompt"], add_special_tokens=False)
|
||||
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
||||
# build grouped texts with format `X1 X2 X3 ...`
|
||||
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
|
||||
kwargs = dict(allowed_special="all")
|
||||
else:
|
||||
kwargs = dict(add_special_tokens=True)
|
||||
|
||||
if hasattr(tokenizer, "add_eos_token"): # for LLaMA tokenizer
|
||||
add_eos_token_flag = getattr(tokenizer, "add_eos_token")
|
||||
setattr(tokenizer, "add_eos_token", True)
|
||||
|
||||
tokenized_examples = tokenizer(examples["prompt"], **kwargs)
|
||||
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
||||
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
||||
block_size = data_args.max_source_length
|
||||
block_size = data_args.cutoff_len
|
||||
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of max_source_length
|
||||
# split by chunks of cutoff_len
|
||||
result = {
|
||||
k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated_examples.items()
|
||||
}
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
# make sure the saved tokenizer is the same as the original one
|
||||
if hasattr(tokenizer, "add_eos_token"):
|
||||
setattr(tokenizer, "add_eos_token", add_eos_token_flag)
|
||||
return result
|
||||
|
||||
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
||||
def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
|
||||
# for input with history, we build multiple input-label pairs just like:
|
||||
# https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
|
||||
# for multiturn examples, we only mask the prompt part in each prompt-response pair.
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
max_length = data_args.max_source_length + data_args.max_target_length
|
||||
|
||||
for query, response, history, prefix in construct_example(examples):
|
||||
for query, response, history, system in construct_example(examples):
|
||||
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
|
||||
continue
|
||||
|
||||
input_ids, labels = [], []
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
||||
tokenizer, query, response, history, system
|
||||
)):
|
||||
total_len = len(source_ids) + len(target_ids)
|
||||
max_source_len = int(data_args.cutoff_len * (len(source_ids) / total_len))
|
||||
max_target_len = int(data_args.cutoff_len * (len(target_ids) / total_len))
|
||||
|
||||
for i, (query_i, resp_i) in enumerate(template.get_dialog(query, response, history, prefix)):
|
||||
source_ids = tokenizer.encode(text=query_i, add_special_tokens=(i == 0))
|
||||
target_ids = tokenizer.encode(text=resp_i, add_special_tokens=False)
|
||||
if len(source_ids) > max_source_len:
|
||||
source_ids = source_ids[:max_source_len]
|
||||
if len(target_ids) > max_target_len:
|
||||
target_ids = target_ids[:max_target_len]
|
||||
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
if len(target_ids) > data_args.max_target_length - 1: # eos token
|
||||
target_ids = target_ids[:data_args.max_target_length - 1]
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
|
||||
if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
|
||||
break
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
|
||||
labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
if len(input_ids) > data_args.cutoff_len:
|
||||
input_ids = input_ids[:data_args.cutoff_len]
|
||||
labels = labels[:data_args.cutoff_len]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
@@ -76,109 +112,161 @@ def preprocess_dataset(
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
||||
# build inputs with format `<bos> X` and labels with format `<bos> Y`
|
||||
def preprocess_packed_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X1 Y1 <eos> <bos> X2 Y2 <eos>`
|
||||
# and labels with format `<ignore> ... <ignore> Y1 <eos> <ignore> ... <ignore> Y2 <eos>`
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
input_ids, labels = [], []
|
||||
for query, response, history, system in construct_example(examples):
|
||||
if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""):
|
||||
continue
|
||||
|
||||
for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn(
|
||||
tokenizer, query, response, history, system
|
||||
)):
|
||||
if data_args.train_on_prompt:
|
||||
source_mask = source_ids
|
||||
elif turn_idx != 0 and template.efficient_eos:
|
||||
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1)
|
||||
else:
|
||||
source_mask = [IGNORE_INDEX] * len(source_ids)
|
||||
input_ids += source_ids + target_ids
|
||||
labels += source_mask + target_ids
|
||||
|
||||
if template.efficient_eos:
|
||||
input_ids += [tokenizer.eos_token_id]
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
total_length = len(input_ids)
|
||||
block_size = data_args.cutoff_len
|
||||
# we drop the small remainder, and if the total_length < block_size, we exclude this batch
|
||||
total_length = (total_length // block_size) * block_size
|
||||
# split by chunks of cutoff_len
|
||||
for i in range(0, total_length, block_size):
|
||||
model_inputs["input_ids"].append(input_ids[i: i + block_size])
|
||||
model_inputs["attention_mask"].append([1] * block_size)
|
||||
model_inputs["labels"].append(labels[i: i + block_size])
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
||||
# build inputs with format `<bos> X` and labels with format `Y <eos>`
|
||||
model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
|
||||
|
||||
for query, response, history, prefix in construct_example(examples):
|
||||
prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
|
||||
for query, response, history, system in construct_example(examples):
|
||||
if not (isinstance(query, str) and query != ""):
|
||||
continue
|
||||
|
||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
|
||||
target_ids = tokenizer.encode(text=response, add_special_tokens=True)
|
||||
input_ids, labels = template.encode_oneturn(tokenizer, query, response, history, system)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
if len(target_ids) > data_args.max_target_length:
|
||||
target_ids = target_ids[:data_args.max_target_length]
|
||||
if template.efficient_eos:
|
||||
labels += [tokenizer.eos_token_id]
|
||||
|
||||
model_inputs["input_ids"].append(source_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(source_ids))
|
||||
model_inputs["labels"].append(target_ids)
|
||||
if len(input_ids) > data_args.cutoff_len:
|
||||
input_ids = input_ids[:data_args.cutoff_len]
|
||||
if len(labels) > data_args.cutoff_len:
|
||||
labels = labels[:data_args.cutoff_len]
|
||||
|
||||
model_inputs["input_ids"].append(input_ids)
|
||||
model_inputs["attention_mask"].append([1] * len(input_ids))
|
||||
model_inputs["labels"].append(labels)
|
||||
|
||||
return model_inputs
|
||||
|
||||
def preprocess_pairwise_dataset(examples):
|
||||
# build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
|
||||
model_inputs = {"accept_ids": [], "reject_ids": []}
|
||||
for query, response, history, prefix in construct_example(examples):
|
||||
prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
|
||||
def preprocess_pairwise_dataset(examples: Dict[str, List[Any]]) -> Dict[str, List[List[int]]]:
|
||||
# build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
|
||||
model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
|
||||
for query, response, history, system in construct_example(examples):
|
||||
if not (isinstance(query, str) and isinstance(response, list) and query != "" and len(response) > 1):
|
||||
continue
|
||||
|
||||
source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
|
||||
accept_ids = tokenizer.encode(text=response[0], add_special_tokens=False)
|
||||
reject_ids = tokenizer.encode(text=response[1], add_special_tokens=False)
|
||||
prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
|
||||
_, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
|
||||
|
||||
if len(source_ids) > data_args.max_source_length:
|
||||
source_ids = source_ids[:data_args.max_source_length]
|
||||
if len(accept_ids) > data_args.max_target_length - 1: # eos token
|
||||
accept_ids = accept_ids[:data_args.max_target_length - 1]
|
||||
if len(reject_ids) > data_args.max_target_length - 1: # eos token
|
||||
reject_ids = reject_ids[:data_args.max_target_length - 1]
|
||||
if template.efficient_eos:
|
||||
chosen_ids += [tokenizer.eos_token_id]
|
||||
rejected_ids += [tokenizer.eos_token_id]
|
||||
|
||||
accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
|
||||
reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]
|
||||
total_len = len(prompt_ids) + max(len(chosen_ids), len(rejected_ids))
|
||||
max_source_len = int(data_args.cutoff_len * (len(prompt_ids) / total_len))
|
||||
max_target_len = int(data_args.cutoff_len * (max(len(chosen_ids), len(rejected_ids)) / total_len))
|
||||
|
||||
if len(prompt_ids) > max_source_len:
|
||||
prompt_ids = prompt_ids[:max_source_len]
|
||||
if len(chosen_ids) > max_target_len:
|
||||
chosen_ids = chosen_ids[:max_target_len]
|
||||
if len(rejected_ids) > max_target_len:
|
||||
rejected_ids = rejected_ids[:max_target_len]
|
||||
|
||||
model_inputs["prompt_ids"].append(prompt_ids)
|
||||
model_inputs["chosen_ids"].append(chosen_ids)
|
||||
model_inputs["rejected_ids"].append(rejected_ids)
|
||||
|
||||
model_inputs["accept_ids"].append(accept_ids)
|
||||
model_inputs["reject_ids"].append(reject_ids)
|
||||
return model_inputs
|
||||
|
||||
def print_supervised_dataset_example(example):
|
||||
def print_supervised_dataset_example(example: Dict[str, List[int]]) -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
print("label_ids:\n{}".format(example["labels"]))
|
||||
print("labels:\n{}".format(
|
||||
tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
|
||||
skip_special_tokens=False)
|
||||
tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False)
|
||||
))
|
||||
|
||||
def print_pairwise_dataset_example(example):
|
||||
print("accept_ids:\n{}".format(example["accept_ids"]))
|
||||
print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
|
||||
print("reject_ids:\n{}".format(example["reject_ids"]))
|
||||
print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
|
||||
def print_pairwise_dataset_example(example: Dict[str, List[int]]) -> None:
|
||||
print("prompt_ids:\n{}".format(example["prompt_ids"]))
|
||||
print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
|
||||
print("chosen_ids:\n{}".format(example["chosen_ids"]))
|
||||
print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
|
||||
print("rejected_ids:\n{}".format(example["rejected_ids"]))
|
||||
print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))
|
||||
|
||||
def print_unsupervised_dataset_example(example):
|
||||
def print_unsupervised_dataset_example(example: Dict[str, List[int]]) -> None:
|
||||
print("input_ids:\n{}".format(example["input_ids"]))
|
||||
print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
|
||||
|
||||
if stage == "pt":
|
||||
dataset = dataset.filter(lambda example: example["prompt"])
|
||||
preprocess_function = preprocess_pretrain_dataset
|
||||
preprocess_func = preprocess_pretrain_dataset
|
||||
print_function = print_unsupervised_dataset_example
|
||||
elif stage == "sft" and not training_args.predict_with_generate:
|
||||
dataset = dataset.filter(lambda example: example["prompt"] and example["response"])
|
||||
preprocess_function = preprocess_supervised_dataset
|
||||
preprocess_func = preprocess_packed_supervised_dataset if data_args.sft_packing else preprocess_supervised_dataset
|
||||
print_function = print_supervised_dataset_example
|
||||
elif stage == "rm":
|
||||
dataset = dataset.filter(lambda example: example["prompt"] and len(example["response"]) > 1)
|
||||
preprocess_function = preprocess_pairwise_dataset
|
||||
preprocess_func = preprocess_pairwise_dataset
|
||||
print_function = print_pairwise_dataset_example
|
||||
else:
|
||||
dataset = dataset.filter(lambda example: example["prompt"])
|
||||
preprocess_function = preprocess_unsupervised_dataset
|
||||
preprocess_func = preprocess_unsupervised_dataset
|
||||
print_function = print_unsupervised_dataset_example
|
||||
|
||||
if data_args.cache_path is not None and os.path.exists(data_args.cache_path):
|
||||
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
||||
return load_from_disk(data_args.cache_path)
|
||||
|
||||
with training_args.main_process_first(desc="dataset map pre-processing"):
|
||||
column_names = list(next(iter(dataset)).keys())
|
||||
kwargs = {}
|
||||
if not data_args.streaming:
|
||||
kwargs = dict(
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
load_from_cache_file=(not data_args.overwrite_cache),
|
||||
desc="Running tokenizer on dataset"
|
||||
)
|
||||
|
||||
dataset = dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
preprocess_func,
|
||||
batched=True,
|
||||
remove_columns=column_names,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
|
||||
if data_args.cache_path is not None and not os.path.exists(data_args.cache_path):
|
||||
if training_args.should_save:
|
||||
dataset.save_to_disk(data_args.cache_path)
|
||||
raise SystemExit("Dataset saved, rerun this script with the same `--cache_path`.")
|
||||
|
||||
if stage == "pt":
|
||||
print_unsupervised_dataset_example(next(iter(dataset)))
|
||||
elif stage == "sft":
|
||||
print_supervised_dataset_example(next(iter(dataset)))
|
||||
elif stage == "rm":
|
||||
print_pairwise_dataset_example(next(iter(dataset)))
|
||||
elif stage == "ppo":
|
||||
print_unsupervised_dataset_example(next(iter(dataset)))
|
||||
if training_args.should_log:
|
||||
try:
|
||||
print_function(next(iter(dataset)))
|
||||
except StopIteration:
|
||||
raise RuntimeError("Empty dataset!")
|
||||
|
||||
return dataset
|
||||
|
||||
@@ -1,15 +1,61 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from datasets import Dataset
|
||||
from datasets import Dataset, IterableDataset
|
||||
from transformers import TrainingArguments
|
||||
from llmtuner.hparams import DataArguments
|
||||
|
||||
|
||||
def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]:
|
||||
if do_train:
|
||||
if dev_ratio > 1e-6: # Split the dataset
|
||||
dataset = dataset.train_test_split(test_size=dev_ratio)
|
||||
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
EXT2TYPE = {
|
||||
"arrow": "arrow",
|
||||
"csv": "csv",
|
||||
"json": "json",
|
||||
"jsonl": "json",
|
||||
"parquet": "parquet",
|
||||
"txt": "text"
|
||||
}
|
||||
|
||||
|
||||
def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
|
||||
if file_sha1 is None:
|
||||
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
|
||||
return
|
||||
|
||||
if len(data_files) != 1:
|
||||
logger.warning("Checksum failed: too many files.")
|
||||
return
|
||||
|
||||
with open(data_files[0], "rb") as f:
|
||||
sha1 = hashlib.sha1(f.read()).hexdigest()
|
||||
if sha1 != file_sha1:
|
||||
logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
|
||||
|
||||
|
||||
def split_dataset(
|
||||
dataset: Union["Dataset", "IterableDataset"],
|
||||
data_args: "DataArguments",
|
||||
training_args: "TrainingArguments"
|
||||
) -> Dict[str, "Dataset"]:
|
||||
if training_args.do_train:
|
||||
if data_args.val_size > 1e-6: # Split the dataset
|
||||
if data_args.streaming:
|
||||
val_set = dataset.take(int(data_args.val_size))
|
||||
train_set = dataset.skip(int(data_args.val_size))
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||
return {"train_dataset": train_set, "eval_dataset": val_set}
|
||||
else:
|
||||
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
|
||||
dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
|
||||
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
||||
else:
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
||||
return {"train_dataset": dataset}
|
||||
else: # do_eval or do_predict
|
||||
return {"eval_dataset": dataset}
|
||||
|
||||
@@ -5,67 +5,151 @@ from typing import TYPE_CHECKING
|
||||
from datetime import timedelta
|
||||
|
||||
from transformers import TrainerCallback
|
||||
from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
|
||||
|
||||
from llmtuner.extras.constants import LOG_FILE_NAME
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainingArguments, TrainerState, TrainerControl
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SavePeftModelCallback(TrainerCallback):
|
||||
|
||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after a checkpoint save.
|
||||
"""
|
||||
if args.should_save:
|
||||
output_dir = os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step))
|
||||
model = kwargs.pop("model")
|
||||
if getattr(model, "is_peft_model", False):
|
||||
getattr(model, "pretrained_model").save_pretrained(output_dir)
|
||||
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
if args.should_save:
|
||||
model = kwargs.pop("model")
|
||||
if getattr(model, "is_peft_model", False):
|
||||
getattr(model, "pretrained_model").save_pretrained(args.output_dir)
|
||||
|
||||
|
||||
class LogCallback(TrainerCallback):
|
||||
|
||||
def __init__(self, runner=None):
|
||||
self.runner = runner
|
||||
self.in_training = False
|
||||
self.start_time = time.time()
|
||||
self.tracker = {}
|
||||
self.cur_steps = 0
|
||||
self.max_steps = 0
|
||||
self.elapsed_time = ""
|
||||
self.remaining_time = ""
|
||||
|
||||
def timing(self):
|
||||
cur_time = time.time()
|
||||
elapsed_time = cur_time - self.start_time
|
||||
avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
|
||||
remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
|
||||
self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
|
||||
self.remaining_time = str(timedelta(seconds=int(remaining_time)))
|
||||
|
||||
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of training.
|
||||
"""
|
||||
self.start_time = time.time()
|
||||
if state.is_local_process_zero:
|
||||
self.in_training = True
|
||||
self.start_time = time.time()
|
||||
self.max_steps = state.max_steps
|
||||
if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
|
||||
logger.warning("Previous log file in this folder will be deleted.")
|
||||
os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
|
||||
|
||||
def on_step_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of a training step. If using gradient accumulation, one training step
|
||||
might take several inputs.
|
||||
Event called at the end of training.
|
||||
"""
|
||||
if self.runner is not None and self.runner.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
if state.is_local_process_zero:
|
||||
self.in_training = False
|
||||
self.cur_steps = 0
|
||||
self.max_steps = 0
|
||||
|
||||
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of an substep during gradient accumulation.
|
||||
"""
|
||||
if self.runner is not None and self.runner.aborted:
|
||||
if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of a training step.
|
||||
"""
|
||||
if state.is_local_process_zero:
|
||||
self.cur_steps = state.global_step
|
||||
self.timing()
|
||||
if self.runner is not None and self.runner.aborted:
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after an evaluation phase.
|
||||
"""
|
||||
if state.is_local_process_zero and not self.in_training:
|
||||
self.cur_steps = 0
|
||||
self.max_steps = 0
|
||||
|
||||
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
|
||||
r"""
|
||||
Event called after a successful prediction.
|
||||
"""
|
||||
if state.is_local_process_zero and not self.in_training:
|
||||
self.cur_steps = 0
|
||||
self.max_steps = 0
|
||||
|
||||
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
|
||||
r"""
|
||||
Event called after logging the last logs.
|
||||
"""
|
||||
if not state.is_world_process_zero:
|
||||
if not state.is_local_process_zero:
|
||||
return
|
||||
|
||||
cur_time = time.time()
|
||||
cur_steps = state.log_history[-1].get("step")
|
||||
elapsed_time = cur_time - self.start_time
|
||||
avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
|
||||
remaining_steps = state.max_steps - cur_steps
|
||||
remaining_time = remaining_steps * avg_time_per_step
|
||||
self.tracker = {
|
||||
"current_steps": cur_steps,
|
||||
"total_steps": state.max_steps,
|
||||
"loss": state.log_history[-1].get("loss", None),
|
||||
"eval_loss": state.log_history[-1].get("eval_loss", None),
|
||||
"predict_loss": state.log_history[-1].get("predict_loss", None),
|
||||
"reward": state.log_history[-1].get("reward", None),
|
||||
"learning_rate": state.log_history[-1].get("learning_rate", None),
|
||||
"epoch": state.log_history[-1].get("epoch", None),
|
||||
"percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
|
||||
"elapsed_time": str(timedelta(seconds=int(elapsed_time))),
|
||||
"remaining_time": str(timedelta(seconds=int(remaining_time)))
|
||||
}
|
||||
logs = dict(
|
||||
current_steps=self.cur_steps,
|
||||
total_steps=self.max_steps,
|
||||
loss=state.log_history[-1].get("loss", None),
|
||||
eval_loss=state.log_history[-1].get("eval_loss", None),
|
||||
predict_loss=state.log_history[-1].get("predict_loss", None),
|
||||
reward=state.log_history[-1].get("reward", None),
|
||||
learning_rate=state.log_history[-1].get("learning_rate", None),
|
||||
epoch=state.log_history[-1].get("epoch", None),
|
||||
percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
|
||||
elapsed_time=self.elapsed_time,
|
||||
remaining_time=self.remaining_time
|
||||
)
|
||||
if self.runner is not None:
|
||||
logger.info("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
|
||||
logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
|
||||
))
|
||||
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
|
||||
f.write(json.dumps(self.tracker) + "\n")
|
||||
f.write(json.dumps(logs) + "\n")
|
||||
|
||||
def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after a prediction step.
|
||||
"""
|
||||
eval_dataloader = kwargs.pop("eval_dataloader", None)
|
||||
if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
|
||||
if self.max_steps == 0:
|
||||
self.max_steps = len(eval_dataloader)
|
||||
self.cur_steps += 1
|
||||
self.timing()
|
||||
|
||||
@@ -1,13 +1,19 @@
|
||||
IGNORE_INDEX = -100
|
||||
|
||||
VALUE_HEAD_FILE_NAME = "value_head.bin"
|
||||
LOG_FILE_NAME = "trainer_log.jsonl"
|
||||
|
||||
FINETUNING_ARGS_NAME = "finetuning_args.json"
|
||||
|
||||
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
|
||||
LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp", "ln_1", "ln_2", "ln1", "ln2"]
|
||||
|
||||
METHODS = ["full", "freeze", "lora"]
|
||||
|
||||
TRAINING_STAGES = {
|
||||
"Supervised Fine-Tuning": "sft",
|
||||
"Reward Modeling": "rm",
|
||||
"PPO": "ppo",
|
||||
"DPO": "dpo",
|
||||
"Pre-Training": "pt"
|
||||
}
|
||||
|
||||
SUPPORTED_MODELS = {
|
||||
"LLaMA-7B": "huggyllama/llama-7b",
|
||||
"LLaMA-13B": "huggyllama/llama-13b",
|
||||
@@ -19,29 +25,68 @@ SUPPORTED_MODELS = {
|
||||
"LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
|
||||
"LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
|
||||
"ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
|
||||
"ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
|
||||
"ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
|
||||
"ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
|
||||
"BLOOM-560M": "bigscience/bloom-560m",
|
||||
"BLOOM-3B": "bigscience/bloom-3b",
|
||||
"BLOOM-7B1": "bigscience/bloom-7b1",
|
||||
"BLOOMZ-560M": "bigscience/bloomz-560m",
|
||||
"BLOOMZ-3B": "bigscience/bloomz-3b",
|
||||
"BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
|
||||
"Falcon-7B-Base": "tiiuae/falcon-7b",
|
||||
"Falcon-7B": "tiiuae/falcon-7b",
|
||||
"Falcon-40B": "tiiuae/falcon-40b",
|
||||
"Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
|
||||
"Falcon-40B-Base": "tiiuae/falcon-40b",
|
||||
"Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
|
||||
"Baichuan-7B": "baichuan-inc/Baichuan-7B",
|
||||
"Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
|
||||
"Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
|
||||
"Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
|
||||
"InternLM-7B-Base": "internlm/internlm-7b",
|
||||
"InternLM-7B-Chat": "internlm/internlm-chat-7b"
|
||||
"Baichuan2-7B": "baichuan-inc/Baichuan2-7B-Base",
|
||||
"Baichuan2-13B": "baichuan-inc/Baichuan2-13B-Base",
|
||||
"Baichuan2-7B-Chat": "baichuan-inc/Baichuan2-7B-Chat",
|
||||
"Baichuan2-13B-Chat": "baichuan-inc/Baichuan2-13B-Chat",
|
||||
"InternLM-7B": "internlm/internlm-7b",
|
||||
"InternLM-20B": "internlm/internlm-20b",
|
||||
"InternLM-7B-Chat": "internlm/internlm-chat-7b",
|
||||
"InternLM-20B-Chat": "internlm/internlm-chat-20b",
|
||||
"Qwen-7B": "Qwen/Qwen-7B",
|
||||
"Qwen-14B": "Qwen/Qwen-14B",
|
||||
"Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
|
||||
"Qwen-14B-Chat": "Qwen/Qwen-14B-Chat",
|
||||
"XVERSE-13B": "xverse/XVERSE-13B",
|
||||
"XVERSE-13B-Chat": "xverse/XVERSE-13B-Chat",
|
||||
"ChatGLM2-6B-Chat": "THUDM/chatglm2-6b",
|
||||
"ChatGLM3-6B-Base": "THUDM/chatglm3-6b-base",
|
||||
"ChatGLM3-6B-Chat": "THUDM/chatglm3-6b",
|
||||
"Phi1.5-1.3B": "microsoft/phi-1_5"
|
||||
}
|
||||
|
||||
DEFAULT_MODULE = {
|
||||
"LLaMA": "q_proj,v_proj",
|
||||
"LLaMA2": "q_proj,v_proj",
|
||||
"ChineseLLaMA2": "q_proj,v_proj",
|
||||
"BLOOM": "query_key_value",
|
||||
"BLOOMZ": "query_key_value",
|
||||
"Falcon": "query_key_value",
|
||||
"Baichuan": "W_pack",
|
||||
"InternLM": "q_proj,v_proj"
|
||||
"Baichuan2": "W_pack",
|
||||
"InternLM": "q_proj,v_proj",
|
||||
"Qwen": "c_attn",
|
||||
"XVERSE": "q_proj,v_proj",
|
||||
"ChatGLM2": "query_key_value",
|
||||
"ChatGLM3": "query_key_value",
|
||||
"Phi1.5": "Wqkv"
|
||||
}
|
||||
|
||||
DEFAULT_TEMPLATE = {
|
||||
"LLaMA2": "llama2",
|
||||
"ChineseLLaMA2": "llama2_zh",
|
||||
"Baichuan": "baichuan",
|
||||
"Baichuan2": "baichuan2",
|
||||
"InternLM": "intern",
|
||||
"Qwen": "chatml",
|
||||
"XVERSE": "xverse",
|
||||
"ChatGLM2": "chatglm2",
|
||||
"ChatGLM3": "chatglm3"
|
||||
}
|
||||
|
||||
@@ -8,6 +8,9 @@ class LoggerHandler(logging.Handler):
|
||||
super().__init__()
|
||||
self.log = ""
|
||||
|
||||
def reset(self):
|
||||
self.log = ""
|
||||
|
||||
def emit(self, record):
|
||||
if record.name == "httpx":
|
||||
return
|
||||
|
||||
@@ -1,10 +1,20 @@
|
||||
import gc
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
|
||||
|
||||
from transformers.generation.utils import LogitsProcessorList
|
||||
from transformers.generation.logits_process import LogitsProcessor
|
||||
|
||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||
try:
|
||||
from transformers.utils import (
|
||||
is_torch_bf16_cpu_available,
|
||||
is_torch_bf16_gpu_available,
|
||||
is_torch_cuda_available,
|
||||
is_torch_npu_available
|
||||
)
|
||||
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
||||
_is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available
|
||||
except ImportError:
|
||||
_is_fp16_available = torch.cuda.is_available()
|
||||
_is_bf16_available = torch.cuda.is_bf16_supported()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
@@ -30,22 +40,6 @@ class AverageMeter:
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
# Avoids runtime error in model.generate(do_sample=True).
|
||||
class InvalidScoreLogitsProcessor(LogitsProcessor):
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if torch.isnan(scores).any() or torch.isinf(scores).any():
|
||||
scores.zero_()
|
||||
scores[..., 0] = 1.0
|
||||
return scores
|
||||
|
||||
|
||||
def get_logits_processor() -> LogitsProcessorList:
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InvalidScoreLogitsProcessor())
|
||||
return logits_processor
|
||||
|
||||
|
||||
def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
r"""
|
||||
Returns the number of trainable parameters and number of all parameters in the model.
|
||||
@@ -68,52 +62,57 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]:
|
||||
return trainable_params, all_param
|
||||
|
||||
|
||||
# Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
|
||||
# Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
|
||||
def prepare_model_for_training(
|
||||
model: "PreTrainedModel",
|
||||
finetuning_type: str,
|
||||
output_layer_name: Optional[str] = "lm_head",
|
||||
use_gradient_checkpointing: Optional[bool] = True,
|
||||
layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
|
||||
) -> "PreTrainedModel":
|
||||
def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype:
|
||||
r"""
|
||||
Infers the optimal dtype according to the model_dtype and device compatibility.
|
||||
"""
|
||||
if _is_bf16_available and model_dtype == torch.bfloat16:
|
||||
return torch.bfloat16
|
||||
elif _is_fp16_available:
|
||||
return torch.float16
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if use_gradient_checkpointing:
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
|
||||
model.gradient_checkpointing_enable()
|
||||
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
||||
|
||||
if finetuning_type != "full" and hasattr(model, output_layer_name):
|
||||
if hasattr(model, "config") and hasattr(model.config, "pretraining_tp"):
|
||||
model.config.pretraining_tp = 1 # disable TP for LoRA (https://github.com/huggingface/peft/pull/728)
|
||||
|
||||
output_layer: torch.nn.Linear = getattr(model, output_layer_name)
|
||||
input_dtype = output_layer.weight.dtype
|
||||
|
||||
class CastOutputToFloat(torch.nn.Sequential):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return super().forward(x.to(input_dtype)).to(torch.float32)
|
||||
|
||||
setattr(model, output_layer_name, CastOutputToFloat(output_layer))
|
||||
|
||||
return model
|
||||
def get_logits_processor() -> LogitsProcessorList:
|
||||
r"""
|
||||
Gets logits processor that removes NaN and Inf logits.
|
||||
"""
|
||||
logits_processor = LogitsProcessorList()
|
||||
logits_processor.append(InfNanRemoveLogitsProcessor())
|
||||
return logits_processor
|
||||
|
||||
|
||||
def torch_gc() -> None:
|
||||
r"""
|
||||
Collects GPU memory.
|
||||
"""
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.ipc_collect()
|
||||
|
||||
|
||||
def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
|
||||
r"""
|
||||
Dispatches a pre-trained model to GPUs with balanced memory.
|
||||
Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
|
||||
"""
|
||||
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
|
||||
return model
|
||||
|
||||
if torch.cuda.device_count() > 1:
|
||||
from accelerate import dispatch_model
|
||||
from accelerate.utils import infer_auto_device_map, get_balanced_memory
|
||||
|
||||
if model._no_split_modules is None:
|
||||
raise ValueError("The model class needs to implement the `_no_split_modules` attribute.")
|
||||
|
||||
kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules}
|
||||
max_memory = get_balanced_memory(model, **kwargs)
|
||||
# Make sure tied weights are tied before creating the device map.
|
||||
model.tie_weights()
|
||||
device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs)
|
||||
return dispatch_model(model, device_map)
|
||||
else:
|
||||
return model.cuda()
|
||||
|
||||
0
src/llmtuner/extras/patches/__init__.py
Normal file
0
src/llmtuner/extras/patches/__init__.py
Normal file
218
src/llmtuner/extras/patches/llama_patch.py
Normal file
218
src/llmtuner/extras/patches/llama_patch.py
Normal file
@@ -0,0 +1,218 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, Tuple
|
||||
from transformers.utils import logging
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
try:
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore
|
||||
from flash_attn.bert_padding import pad_input, unpad_input # type: ignore
|
||||
except ImportError:
|
||||
print("FlashAttention-2 is not installed, ignore this if you are not using FlashAttention.")
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Modified from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
class LlamaShiftShortAttention(LlamaAttention):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None: # reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
if getattr(self, "num_key_value_groups"):
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
||||
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||
num_groups = q_len // groupsz
|
||||
def shift(state: torch.Tensor) -> torch.Tensor:
|
||||
state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim)
|
||||
state = torch.cat((
|
||||
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
|
||||
), dim=2)
|
||||
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1)
|
||||
|
||||
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
||||
|
||||
if attention_mask is not None:
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz*n_group, :, groupsz, :)
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
||||
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
||||
attn_output = torch.cat((
|
||||
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
|
||||
))
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
class LlamaFlashAttention2(LlamaAttention):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# LlamaFlashAttention2 attention does not support output_attentions
|
||||
output_attentions = False
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
# FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim)
|
||||
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
||||
|
||||
if past_key_value is not None: # reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# cast to half precision
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
logger.warning_once("The input hidden states seems to be silently casted in float32.")
|
||||
query_states = query_states.to(self.config.torch_dtype)
|
||||
key_states = key_states.to(self.config.torch_dtype)
|
||||
value_states = value_states.to(self.config.torch_dtype)
|
||||
|
||||
if getattr(self, "num_key_value_groups", None):
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
query_states = query_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||
key_states = key_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||
value_states = value_states.transpose(1, 2) # (bsz, seq_len, n_heads, head_dim)
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift
|
||||
groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
|
||||
assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
|
||||
num_groups = q_len // groupsz
|
||||
def shift(state: torch.Tensor) -> torch.Tensor:
|
||||
state = torch.cat((
|
||||
state[:, :, :self.num_heads//2], state[:, :, self.num_heads//2:].roll(-groupsz//2, dims=1)
|
||||
), dim=2)
|
||||
return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim)
|
||||
|
||||
query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states)
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.reshape(bsz * num_groups, groupsz)
|
||||
|
||||
if attention_mask is not None:
|
||||
logger.warning_once("Padded sequences are less efficient in FlashAttention.")
|
||||
# -q_len: assumes left padding when q_len != kv_len
|
||||
unpadded_q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask[:, -q_len:])
|
||||
unpadded_k, _, cu_seqlens_k, max_seqlen_k = unpad_input(key_states, attention_mask)
|
||||
unpadded_v, _, _, _ = unpad_input(value_states, attention_mask)
|
||||
attn_output_unpad = flash_attn_varlen_func(
|
||||
unpadded_q,
|
||||
unpadded_k,
|
||||
unpadded_v,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
)
|
||||
attn_output = pad_input(attn_output_unpad, indices_q, bsz, q_len)
|
||||
else:
|
||||
attn_output = flash_attn_func(
|
||||
query_states, key_states, value_states, 0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
|
||||
if getattr(self.config, "group_size_ratio", None) and self.training: # shift back
|
||||
attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim)
|
||||
attn_output = torch.cat((
|
||||
attn_output[:, :, :self.num_heads//2], attn_output[:, :, self.num_heads//2:].roll(groupsz//2, dims=1)
|
||||
))
|
||||
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as flash attention
|
||||
# takes a boolean padding_mask. Fills in the past kv length for use in forward.
|
||||
def _prepare_decoder_attention_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
input_shape: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
past_key_values_length: int
|
||||
) -> torch.Tensor:
|
||||
if attention_mask is not None and torch.all(attention_mask):
|
||||
return None # This uses the faster call when training with full samples
|
||||
|
||||
return attention_mask
|
||||
@@ -1,49 +0,0 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import Dict
|
||||
|
||||
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||
from transformers.modeling_utils import load_sharded_checkpoint
|
||||
|
||||
from llmtuner.extras.constants import VALUE_HEAD_FILE_NAME
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
|
||||
state_dict: Dict[str, torch.Tensor] = model.state_dict()
|
||||
filtered_state_dict = {}
|
||||
|
||||
for k, v in model.named_parameters():
|
||||
if v.requires_grad:
|
||||
filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
|
||||
|
||||
return filtered_state_dict
|
||||
|
||||
|
||||
def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
|
||||
weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
|
||||
if os.path.exists(weights_file):
|
||||
model_state_dict = torch.load(weights_file, map_location="cpu")
|
||||
model.load_state_dict(model_state_dict, strict=False) # skip missing keys
|
||||
elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
|
||||
load_sharded_checkpoint(model, checkpoint_dir, strict=False)
|
||||
else:
|
||||
logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
|
||||
valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
|
||||
if not os.path.exists(valuehead_file):
|
||||
logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
|
||||
return False
|
||||
valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
|
||||
model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
|
||||
model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
|
||||
model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
|
||||
model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
|
||||
return True
|
||||
@@ -1,85 +1,428 @@
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import tiktoken
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Template:
|
||||
|
||||
prefix: str
|
||||
prompt: str
|
||||
sep: str
|
||||
prefix: List[Union[str, Dict[str, str]]]
|
||||
prompt: List[Union[str, Dict[str, str]]]
|
||||
system: str
|
||||
sep: List[Union[str, Dict[str, str]]]
|
||||
stop_words: List[str]
|
||||
use_history: bool
|
||||
efficient_eos: bool
|
||||
|
||||
def get_prompt(
|
||||
def encode_oneturn(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
query: str,
|
||||
resp: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = "",
|
||||
eos_token: Optional[str] = "</s>"
|
||||
) -> str:
|
||||
system: Optional[str] = None
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
r"""
|
||||
Returns a string containing prompt without response.
|
||||
Returns a single pair of token ids representing prompt and response respectively.
|
||||
"""
|
||||
return eos_token.join(map(lambda x: x[0] + x[1], self._format_example(query, history, prefix)))
|
||||
system, history = self._format(query, resp, history, system)
|
||||
encoded_pairs = self._encode(tokenizer, system, history)
|
||||
prompt_ids = []
|
||||
for query_ids, resp_ids in encoded_pairs[:-1]:
|
||||
prompt_ids = prompt_ids + query_ids + resp_ids
|
||||
prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
|
||||
return prompt_ids, answer_ids
|
||||
|
||||
def get_dialog(
|
||||
def encode_multiturn(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
query: str,
|
||||
resp: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
system: Optional[str] = None
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Returns multiple pairs of token ids representing prompts and responses respectively.
|
||||
"""
|
||||
system, history = self._format(query, resp, history, system)
|
||||
encoded_pairs = self._encode(tokenizer, system, history)
|
||||
return encoded_pairs
|
||||
|
||||
def _format(
|
||||
self,
|
||||
query: str,
|
||||
resp: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = ""
|
||||
) -> List[Tuple[str, str]]:
|
||||
system: Optional[str] = None
|
||||
) -> Tuple[str, List[Tuple[str, str]]]:
|
||||
r"""
|
||||
Returns a list containing prompt-response pairs.
|
||||
Aligns inputs to the standard format.
|
||||
"""
|
||||
result = self._format_example(query, history, prefix)
|
||||
result[-1][-1] = resp
|
||||
return result
|
||||
|
||||
def _format_example(
|
||||
self,
|
||||
query: str,
|
||||
history: Optional[List[Tuple[str, str]]] = None,
|
||||
prefix: Optional[str] = ""
|
||||
) -> List[Tuple[str, str]]:
|
||||
prefix = prefix or self.prefix # use prefix if provided
|
||||
prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
|
||||
system = system or self.system # use system if provided
|
||||
history = history if (history and self.use_history) else []
|
||||
history = history + [(query, "")]
|
||||
convs = [
|
||||
[(self.sep if turn_idx else prefix) + self.prompt.format(query=query_i), resp_i]
|
||||
for turn_idx, (query_i, resp_i) in enumerate(history)
|
||||
]
|
||||
return convs
|
||||
history = history + [(query, resp)]
|
||||
return system, history
|
||||
|
||||
def _get_special_ids(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True):
|
||||
bos_ids = [tokenizer.bos_token_id]
|
||||
else: # baichuan, qwen and gpt2 models have no bos token
|
||||
bos_ids = []
|
||||
|
||||
if tokenizer.eos_token_id is None:
|
||||
raise ValueError("EOS token is required.")
|
||||
|
||||
if self.efficient_eos: # used in baichuan, qwen, chatglm, etc.
|
||||
eos_ids = []
|
||||
else:
|
||||
eos_ids = [tokenizer.eos_token_id]
|
||||
|
||||
return bos_ids, eos_ids
|
||||
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
system: str,
|
||||
history: List[Tuple[str, str]]
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Encodes formatted inputs to pairs of token ids.
|
||||
Turn 0: bos + prefix + sep + query resp + eos
|
||||
Turn t: sep + bos + query resp + eos
|
||||
"""
|
||||
bos_ids, eos_ids = self._get_special_ids(tokenizer)
|
||||
sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
|
||||
encoded_pairs = []
|
||||
for turn_idx, (query, resp) in enumerate(history):
|
||||
if turn_idx == 0:
|
||||
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
|
||||
if len(prefix_ids) != 0: # has prefix
|
||||
prefix_ids = bos_ids + prefix_ids + sep_ids
|
||||
else:
|
||||
prefix_ids = bos_ids
|
||||
else:
|
||||
prefix_ids = sep_ids + bos_ids
|
||||
|
||||
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx))
|
||||
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
|
||||
encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
|
||||
return encoded_pairs
|
||||
|
||||
def _convert_inputs_to_ids(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
context: List[Union[str, Dict[str, str]]],
|
||||
system: Optional[str] = None,
|
||||
query: Optional[str] = None,
|
||||
idx: Optional[str] = None
|
||||
) -> List[int]:
|
||||
r"""
|
||||
Converts context to token ids.
|
||||
"""
|
||||
if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
|
||||
kwargs = dict(allowed_special="all")
|
||||
else:
|
||||
kwargs = dict(add_special_tokens=False)
|
||||
|
||||
token_ids = []
|
||||
for elem in context:
|
||||
if isinstance(elem, str):
|
||||
elem = elem.replace("{{system}}", system, 1) if system is not None else elem
|
||||
elem = elem.replace("{{query}}", query, 1) if query is not None else elem
|
||||
elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
|
||||
if len(elem) != 0:
|
||||
token_ids = token_ids + tokenizer.encode(elem, **kwargs)
|
||||
elif isinstance(elem, dict):
|
||||
token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
|
||||
else:
|
||||
raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem)))
|
||||
|
||||
return token_ids
|
||||
|
||||
|
||||
@dataclass
|
||||
class Llama2Template(Template):
|
||||
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
system: str,
|
||||
history: List[Tuple[str, str]]
|
||||
) -> List[Tuple[List[int], List[int]]]:
|
||||
r"""
|
||||
Encodes formatted inputs to pairs of token ids.
|
||||
Turn 0: bos + prefix + query resp + eos
|
||||
Turn t: bos + query resp + eos
|
||||
"""
|
||||
bos_ids, eos_ids = self._get_special_ids(tokenizer)
|
||||
encoded_pairs = []
|
||||
for turn_idx, (query, resp) in enumerate(history):
|
||||
if turn_idx == 0: # llama2 template has no sep_ids
|
||||
query = self.prefix[0].replace("{{system}}", system) + query
|
||||
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
|
||||
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
|
||||
encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
|
||||
return encoded_pairs
|
||||
|
||||
|
||||
templates: Dict[str, Template] = {}
|
||||
|
||||
|
||||
def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
|
||||
templates[name] = Template(
|
||||
def register_template(
|
||||
name: str,
|
||||
prefix: List[Union[str, Dict[str, str]]],
|
||||
prompt: List[Union[str, Dict[str, str]]],
|
||||
system: str,
|
||||
sep: List[Union[str, Dict[str, str]]],
|
||||
stop_words: Optional[List[str]] = [],
|
||||
use_history: Optional[bool] = True,
|
||||
efficient_eos: Optional[bool] = False
|
||||
) -> None:
|
||||
template_class = Llama2Template if "llama2" in name else Template
|
||||
templates[name] = template_class(
|
||||
prefix=prefix,
|
||||
prompt=prompt,
|
||||
system=system,
|
||||
sep=sep,
|
||||
use_history=use_history
|
||||
stop_words=stop_words,
|
||||
use_history=use_history,
|
||||
efficient_eos=efficient_eos
|
||||
)
|
||||
|
||||
|
||||
def get_template(name: str) -> Template:
|
||||
def get_template_and_fix_tokenizer(
|
||||
name: str,
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
) -> Template:
|
||||
if tokenizer.eos_token_id is None:
|
||||
tokenizer.eos_token = "<|endoftext|>"
|
||||
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
||||
|
||||
if name is None:
|
||||
return None
|
||||
|
||||
template = templates.get(name, None)
|
||||
assert template is not None, "Template {} does not exist.".format(name)
|
||||
tokenizer.add_special_tokens(
|
||||
dict(additional_special_tokens=template.stop_words),
|
||||
replace_additional_special_tokens=False
|
||||
)
|
||||
return template
|
||||
|
||||
|
||||
r"""
|
||||
Supports language model inference without histories.
|
||||
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
|
||||
"""
|
||||
register_template(
|
||||
name="vanilla",
|
||||
prefix="",
|
||||
prompt="{query}",
|
||||
sep="",
|
||||
use_history=False
|
||||
name="alpaca",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"### Instruction:\n{{query}}\n\n### Response:\n"
|
||||
],
|
||||
system=(
|
||||
"Below is an instruction that describes a task. "
|
||||
"Write a response that appropriately completes the request."
|
||||
),
|
||||
sep=[
|
||||
"\n\n"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/BAAI/AquilaChat-7B
|
||||
https://huggingface.co/BAAI/AquilaChat2-7B
|
||||
https://huggingface.co/BAAI/AquilaChat2-34B
|
||||
"""
|
||||
register_template(
|
||||
name="aquila",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}}###Assistant:"
|
||||
],
|
||||
system=(
|
||||
"A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||
),
|
||||
sep=[
|
||||
"###"
|
||||
],
|
||||
stop_words=[
|
||||
"</s>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
|
||||
"""
|
||||
register_template(
|
||||
name="baichuan",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<reserved_102>"}, # user token
|
||||
"{{query}}",
|
||||
{"token": "<reserved_103>"} # assistant token
|
||||
],
|
||||
system="",
|
||||
sep=[],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
|
||||
https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
|
||||
"""
|
||||
register_template(
|
||||
name="baichuan2",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<reserved_106>"}, # user token
|
||||
"{{query}}",
|
||||
{"token": "<reserved_107>"} # assistant token
|
||||
],
|
||||
system="",
|
||||
sep=[],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
|
||||
"""
|
||||
register_template(
|
||||
name="belle",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}}\n\nBelle: "
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n\n"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/vivo-ai/BlueLM-7B-Chat
|
||||
"""
|
||||
register_template(
|
||||
name="bluelm",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "[|Human|]:"},
|
||||
"{{query}}",
|
||||
{"token": "[|AI|]:"}
|
||||
],
|
||||
system="",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/THUDM/chatglm2-6b
|
||||
"""
|
||||
register_template(
|
||||
name="chatglm2",
|
||||
prefix=[
|
||||
{"token": "[gMASK]"},
|
||||
{"token": "sop"},
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"[Round {{idx}}]\n\n问:{{query}}\n\n答:"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n\n"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/THUDM/chatglm3-6b
|
||||
"""
|
||||
register_template(
|
||||
name="chatglm3",
|
||||
prefix=[
|
||||
{"token": "[gMASK]"},
|
||||
{"token": "sop"},
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<|user|>"},
|
||||
"\n",
|
||||
"{{query}}",
|
||||
{"token": "<|assistant|>"}
|
||||
],
|
||||
system="",
|
||||
sep=[],
|
||||
stop_words=[
|
||||
"<|user|>",
|
||||
"<|observation|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/deepseek-ai/deepseek-coder-1.3b-instruct
|
||||
https://huggingface.co/deepseek-ai/deepseek-coder-6.7b-instruct
|
||||
https://huggingface.co/deepseek-ai/deepseek-coder-33b-instruct
|
||||
"""
|
||||
register_template(
|
||||
name="deepseek",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"### Instruction:\n{{query}}\n\n### Response:\n"
|
||||
],
|
||||
system=(
|
||||
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
|
||||
"developed by Deepseek Company, and you only answer questions related to computer science. "
|
||||
"For politically sensitive questions, security and privacy issues, "
|
||||
"and other non-computer science questions, you will refuse to answer."
|
||||
),
|
||||
sep=[
|
||||
"\n",
|
||||
{"token": "<|EOT|>"},
|
||||
"\n\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|EOT|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
@@ -88,11 +431,45 @@ Default template.
|
||||
"""
|
||||
register_template(
|
||||
name="default",
|
||||
prefix="A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
||||
prompt="Human: {query}\nAssistant: ",
|
||||
sep="\n",
|
||||
use_history=True
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}}\nAssistant:"
|
||||
],
|
||||
system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
),
|
||||
sep=[
|
||||
"\n"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/internlm/internlm-chat-7b
|
||||
https://huggingface.co/internlm/internlm-chat-20b
|
||||
"""
|
||||
register_template(
|
||||
name="intern",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"<|User|>:{{query}}",
|
||||
{"token": "<eoh>"},
|
||||
"\n<|Bot|>:"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
{"token": "<eoa>"},
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<eoa>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
@@ -103,130 +480,110 @@ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
|
||||
"""
|
||||
register_template(
|
||||
name="llama2",
|
||||
prefix="<<SYS>>\nYou are a helpful, respectful and honest assistant. "
|
||||
"Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, "
|
||||
"racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n"
|
||||
"If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
|
||||
prompt=" [INST] {query} [/INST] ",
|
||||
sep="",
|
||||
use_history=True
|
||||
prefix=[
|
||||
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
|
||||
],
|
||||
prompt=[
|
||||
"[INST] {{query}} [/INST]"
|
||||
],
|
||||
system=(
|
||||
"You are a helpful, respectful and honest assistant. "
|
||||
"Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, "
|
||||
"racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
||||
"If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information."
|
||||
),
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
|
||||
https://github.com/ymcui/Chinese-LLaMA-Alpaca
|
||||
Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
|
||||
https://huggingface.co/ziqingyang/chinese-alpaca-2-13b
|
||||
"""
|
||||
register_template(
|
||||
name="alpaca",
|
||||
prefix="Below is an instruction that describes a task. "
|
||||
"Write a response that appropriately completes the request.",
|
||||
prompt="### Instruction:\n{query}\n\n### Response:\n",
|
||||
sep="\n\n",
|
||||
use_history=True
|
||||
name="llama2_zh",
|
||||
prefix=[
|
||||
"<<SYS>>\n{{system}}\n<</SYS>>\n\n"
|
||||
],
|
||||
prompt=[
|
||||
"[INST] {{query}} [/INST]"
|
||||
],
|
||||
system="You are a helpful assistant. 你是一个乐于助人的助手。",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
|
||||
https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
|
||||
Supports: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
|
||||
"""
|
||||
register_template(
|
||||
name="vicuna",
|
||||
prefix="A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
||||
prompt="USER: {query} ASSISTANT: ",
|
||||
sep="",
|
||||
use_history=True
|
||||
name="mistral",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"[INST] {{query}} [/INST]"
|
||||
],
|
||||
system="",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
|
||||
Supports: https://huggingface.co/openchat/openchat_3.5
|
||||
"""
|
||||
register_template(
|
||||
name="belle",
|
||||
prefix="",
|
||||
prompt="Human: {query}\n\nBelle: ",
|
||||
sep="\n\n",
|
||||
use_history=True
|
||||
name="openchat",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"GPT4 Correct User: {{query}}",
|
||||
{"token": "<|end_of_turn|>"},
|
||||
"GPT4 Correct Assistant:"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
{"token": "<|end_of_turn|>"}
|
||||
],
|
||||
stop_words=[
|
||||
"<|end_of_turn|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://github.com/CVI-SZU/Linly
|
||||
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
|
||||
https://huggingface.co/Qwen/Qwen-14B-Chat
|
||||
"""
|
||||
register_template(
|
||||
name="linly",
|
||||
prefix="",
|
||||
prompt="User: {query}\nBot: ",
|
||||
sep="\n",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://github.com/Neutralzz/BiLLa
|
||||
"""
|
||||
register_template(
|
||||
name="billa",
|
||||
prefix="",
|
||||
prompt="Human: {query}\nAssistant: ",
|
||||
sep="\n",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
|
||||
"""
|
||||
register_template(
|
||||
name="ziya",
|
||||
prefix="",
|
||||
prompt="<human>:{query}\n<bot>:",
|
||||
sep="\n",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/qhduan/aquilachat-7b
|
||||
"""
|
||||
register_template(
|
||||
name="aquila",
|
||||
prefix="A chat between a curious human and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
||||
prompt="Human: {query}###Assistant: ",
|
||||
sep="###",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/internlm/internlm-chat-7b
|
||||
"""
|
||||
register_template(
|
||||
name="intern",
|
||||
prefix="",
|
||||
prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
|
||||
sep="<eoa>\n",
|
||||
use_history=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
|
||||
"""
|
||||
register_template(
|
||||
name="baichuan",
|
||||
prefix="",
|
||||
prompt="<reserved_102>{query}<reserved_103>",
|
||||
sep="",
|
||||
use_history=True
|
||||
name="qwen",
|
||||
prefix=[
|
||||
{"token": "<|im_start|>"},
|
||||
"system\n{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<|im_start|>"},
|
||||
"user\n{{query}}",
|
||||
{"token": "<|im_end|>"},
|
||||
"\n",
|
||||
{"token": "<|im_start|>"},
|
||||
"assistant\n"
|
||||
],
|
||||
system="You are a helpful assistant.",
|
||||
sep=[
|
||||
{"token": "<|im_end|>"},
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|im_end|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
@@ -236,8 +593,158 @@ Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
|
||||
"""
|
||||
register_template(
|
||||
name="starchat",
|
||||
prefix="<|system|>\n",
|
||||
prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
|
||||
sep="<|end|>\n",
|
||||
use_history=True
|
||||
prefix=[
|
||||
{"token": "<|system|>"},
|
||||
"\n{{system}}",
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<|user|>"},
|
||||
"\n{{query}}",
|
||||
{"token": "<|end|>"},
|
||||
"\n",
|
||||
{"token": "<|assistant|>"}
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
{"token": "<|end|>"},
|
||||
"\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|end|>"
|
||||
],
|
||||
efficient_eos=True
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports language model inference without histories.
|
||||
"""
|
||||
register_template(
|
||||
name="vanilla",
|
||||
prefix=[],
|
||||
prompt=[
|
||||
"{{query}}"
|
||||
],
|
||||
system="",
|
||||
sep=[],
|
||||
use_history=False
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/lmsys/vicuna-7b-v1.5
|
||||
https://huggingface.co/lmsys/vicuna-13b-v1.5
|
||||
"""
|
||||
register_template(
|
||||
name="vicuna",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"USER: {{query}} ASSISTANT:"
|
||||
],
|
||||
system=(
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
),
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/xverse/XVERSE-7B-Chat
|
||||
https://huggingface.co/xverse/XVERSE-13B-Chat
|
||||
"""
|
||||
register_template(
|
||||
name="xverse",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
"Human: {{query}}\n\nAssistant: "
|
||||
],
|
||||
system="",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/wenge-research/yayi-7b
|
||||
https://huggingface.co/wenge-research/yayi-7b-llama2
|
||||
https://huggingface.co/wenge-research/yayi-13b-llama2
|
||||
"""
|
||||
register_template(
|
||||
name="yayi",
|
||||
prefix=[
|
||||
{"token": "<|System|>"},
|
||||
":\n{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<|Human|>"},
|
||||
":\n{{query}}\n\n",
|
||||
{"token": "<|YaYi|>"},
|
||||
":"
|
||||
],
|
||||
system=(
|
||||
"You are a helpful, respectful and honest assistant named YaYi "
|
||||
"developed by Beijing Wenge Technology Co.,Ltd. "
|
||||
"Always answer as helpfully as possible, while being safe. "
|
||||
"Your answers should not include any harmful, unethical, "
|
||||
"racist, sexist, toxic, dangerous, or illegal content. "
|
||||
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
||||
"If a question does not make any sense, or is not factually coherent, "
|
||||
"explain why instead of answering something not correct. "
|
||||
"If you don't know the answer to a question, please don't share false information."
|
||||
),
|
||||
sep=[
|
||||
"\n\n"
|
||||
],
|
||||
stop_words=[
|
||||
"<|End|>"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/HuggingFaceH4/zephyr-7b-alpha
|
||||
https://huggingface.co/HuggingFaceH4/zephyr-7b-beta
|
||||
"""
|
||||
register_template(
|
||||
name="zephyr",
|
||||
prefix=[
|
||||
{"token": "<|system|>"},
|
||||
"\n{{system}}",
|
||||
{"token": "</s>"}
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<|user|>"},
|
||||
"\n{{query}}",
|
||||
{"token": "</s>"},
|
||||
{"token": "<|assistant|>"}
|
||||
],
|
||||
system="You are a friendly chatbot who always responds in the style of a pirate",
|
||||
sep=[]
|
||||
)
|
||||
|
||||
|
||||
r"""
|
||||
Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
|
||||
https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1.1
|
||||
https://huggingface.co/IDEA-CCNL/Ziya2-13B-Chat
|
||||
"""
|
||||
register_template(
|
||||
name="ziya",
|
||||
prefix=[
|
||||
"{{system}}"
|
||||
],
|
||||
prompt=[
|
||||
{"token": "<human>"},
|
||||
":{{query}}\n",
|
||||
{"token": "<bot>"},
|
||||
":"
|
||||
],
|
||||
system="",
|
||||
sep=[
|
||||
"\n"
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from .data_args import DataArguments
|
||||
from .finetuning_args import FinetuningArguments
|
||||
from .general_args import GeneralArguments
|
||||
from .generating_args import GeneratingArguments
|
||||
from .model_args import ModelArguments
|
||||
|
||||
@@ -10,28 +10,34 @@ class DatasetAttr:
|
||||
load_from: str
|
||||
dataset_name: Optional[str] = None
|
||||
dataset_sha1: Optional[str] = None
|
||||
source_prefix: Optional[str] = None
|
||||
system_prompt: Optional[str] = None
|
||||
subset: Optional[str] = None
|
||||
ranking: Optional[bool] = False
|
||||
formatting: Optional[Literal["alpaca", "sharegpt"]] = "alpaca"
|
||||
|
||||
prompt: Optional[str] = "instruction"
|
||||
query: Optional[str] = "input"
|
||||
response: Optional[str] = "output"
|
||||
history: Optional[str] = None
|
||||
messages: Optional[str] = "conversations"
|
||||
role: Optional[str] = "from"
|
||||
content: Optional[str] = "value"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.dataset_name
|
||||
|
||||
def __post_init__(self):
|
||||
self.prompt = "instruction"
|
||||
self.query = "input"
|
||||
self.response = "output"
|
||||
self.history = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
"""
|
||||
r"""
|
||||
Arguments pertaining to what data we are going to input our model for training and evaluation.
|
||||
"""
|
||||
template: str = field(
|
||||
template: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Which template to use for constructing prompts in training and inference."}
|
||||
)
|
||||
dataset: Optional[str] = field(
|
||||
default="alpaca_en",
|
||||
default=None,
|
||||
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
|
||||
)
|
||||
dataset_dir: Optional[str] = field(
|
||||
@@ -42,17 +48,29 @@ class DataArguments:
|
||||
default="train",
|
||||
metadata={"help": "Which dataset split to use for training and evaluation."}
|
||||
)
|
||||
cutoff_len: Optional[int] = field(
|
||||
default=1024,
|
||||
metadata={"help": "The maximum length of the model inputs after tokenization."}
|
||||
)
|
||||
train_on_prompt: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to disable the mask on the prompt or not."}
|
||||
)
|
||||
streaming: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable streaming mode."}
|
||||
metadata={"help": "Enable dataset streaming."}
|
||||
)
|
||||
buffer_size: Optional[int] = field(
|
||||
default=16384,
|
||||
metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
|
||||
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}
|
||||
)
|
||||
mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(
|
||||
default="concat",
|
||||
metadata={"help": "Strategy to use in dataset mixing."}
|
||||
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}
|
||||
)
|
||||
interleave_probs: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}
|
||||
)
|
||||
overwrite_cache: Optional[bool] = field(
|
||||
default=False,
|
||||
@@ -62,14 +80,6 @@ class DataArguments:
|
||||
default=None,
|
||||
metadata={"help": "The number of processes to use for the preprocessing."}
|
||||
)
|
||||
max_source_length: Optional[int] = field(
|
||||
default=512,
|
||||
metadata={"help": "The maximum total input sequence length after tokenization."}
|
||||
)
|
||||
max_target_length: Optional[int] = field(
|
||||
default=512,
|
||||
metadata={"help": "The maximum total output sequence length after tokenization."}
|
||||
)
|
||||
max_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
|
||||
@@ -82,26 +92,50 @@ class DataArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
|
||||
)
|
||||
source_prefix: Optional[str] = field(
|
||||
system_prompt: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
|
||||
metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
|
||||
)
|
||||
dev_ratio: Optional[float] = field(
|
||||
val_size: Optional[float] = field(
|
||||
default=0,
|
||||
metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
|
||||
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
|
||||
)
|
||||
sft_packing: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Packing the questions and answers in the supervised fine-tuning stage."}
|
||||
)
|
||||
cache_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to save or load the preprocessed datasets."}
|
||||
)
|
||||
|
||||
def init_for_training(self): # support mixing multiple datasets
|
||||
dataset_names = [ds.strip() for ds in self.dataset.split(",")]
|
||||
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
||||
dataset_info = json.load(f)
|
||||
def __post_init__(self):
|
||||
if self.streaming and self.val_size > 1e-6 and self.val_size < 1:
|
||||
raise ValueError("Streaming mode should have an integer val size.")
|
||||
|
||||
if self.source_prefix is not None:
|
||||
prefix_list = self.source_prefix.split("|")
|
||||
prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
|
||||
assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
|
||||
else:
|
||||
prefix_list = [None] * len(dataset_names)
|
||||
if self.streaming and self.max_samples is not None:
|
||||
raise ValueError("`max_samples` is incompatible with `streaming`.")
|
||||
|
||||
if self.streaming and self.cache_path:
|
||||
raise ValueError("`cache_path` is incompatible with `streaming`.")
|
||||
|
||||
def init_for_training(self, seed: int): # support mixing multiple datasets
|
||||
self.seed = seed
|
||||
dataset_names = [ds.strip() for ds in self.dataset.split(",")] if self.dataset is not None else []
|
||||
try:
|
||||
with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
|
||||
dataset_info = json.load(f)
|
||||
except Exception:
|
||||
if self.dataset is not None:
|
||||
raise ValueError("Cannot find dataset_info.json in `dataset_dir`.")
|
||||
dataset_info = None
|
||||
|
||||
prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
|
||||
prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
|
||||
assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
|
||||
|
||||
if self.interleave_probs is not None:
|
||||
self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
|
||||
|
||||
self.dataset_list: List[DatasetAttr] = []
|
||||
for i, name in enumerate(dataset_names):
|
||||
@@ -119,12 +153,17 @@ class DataArguments:
|
||||
dataset_sha1=dataset_info[name].get("file_sha1", None)
|
||||
)
|
||||
|
||||
dataset_attr.source_prefix = prefix_list[i]
|
||||
|
||||
if "columns" in dataset_info[name]:
|
||||
dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
|
||||
dataset_attr.query = dataset_info[name]["columns"].get("query", None)
|
||||
dataset_attr.response = dataset_info[name]["columns"].get("response", None)
|
||||
dataset_attr.history = dataset_info[name]["columns"].get("history", None)
|
||||
dataset_attr.messages = dataset_info[name]["columns"].get("messages", None)
|
||||
dataset_attr.role = dataset_info[name]["columns"].get("role", None)
|
||||
dataset_attr.content = dataset_info[name]["columns"].get("content", None)
|
||||
|
||||
dataset_attr.subset = dataset_info[name].get("subset", None)
|
||||
dataset_attr.ranking = dataset_info[name].get("ranking", False)
|
||||
dataset_attr.formatting = dataset_info[name].get("formatting", "alpaca")
|
||||
dataset_attr.system_prompt = prompt_list[i]
|
||||
self.dataset_list.append(dataset_attr)
|
||||
|
||||
@@ -5,32 +5,29 @@ from dataclasses import asdict, dataclass, field
|
||||
|
||||
@dataclass
|
||||
class FinetuningArguments:
|
||||
"""
|
||||
r"""
|
||||
Arguments pertaining to which techniques we are going to fine-tuning with.
|
||||
"""
|
||||
finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
|
||||
default="sft",
|
||||
metadata={"help": "Which stage will be performed in training."}
|
||||
)
|
||||
finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
|
||||
default="lora",
|
||||
metadata={"help": "Which fine-tuning method to use."}
|
||||
)
|
||||
num_hidden_layers: Optional[int] = field(
|
||||
default=32,
|
||||
metadata={"help": "Number of decoder blocks in the model. \
|
||||
LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
|
||||
LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
|
||||
BLOOM choices: [\"24\", \"30\", \"70\"], \
|
||||
Falcon choices: [\"32\", \"60\"], \
|
||||
Baichuan choices: [\"32\", \"40\"]"}
|
||||
)
|
||||
num_layer_trainable: Optional[int] = field(
|
||||
default=3,
|
||||
metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
|
||||
metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
|
||||
)
|
||||
name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
|
||||
default="mlp",
|
||||
metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
|
||||
LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \
|
||||
BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
|
||||
Baichuan choices: [\"mlp\", \"self_attn\"]"}
|
||||
metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
|
||||
LLaMA choices: [\"mlp\", \"self_attn\"], \
|
||||
BLOOM & Falcon & ChatGLM choices: [\"mlp\", \"self_attention\"], \
|
||||
Qwen choices: [\"mlp\", \"attn\"], \
|
||||
Phi-1.5 choices: [\"mlp\", \"mixer\"], \
|
||||
LLaMA-2, BlueLM, Baichuan, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."}
|
||||
)
|
||||
lora_rank: Optional[int] = field(
|
||||
default=8,
|
||||
@@ -45,35 +42,74 @@ class FinetuningArguments:
|
||||
metadata={"help": "Dropout rate for the LoRA fine-tuning."}
|
||||
)
|
||||
lora_target: Optional[str] = field(
|
||||
default="q_proj,v_proj",
|
||||
default=None,
|
||||
metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
|
||||
LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
||||
BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
|
||||
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
|
||||
LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
||||
BLOOM & Falcon & ChatGLM choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
|
||||
Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
|
||||
Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
|
||||
Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
|
||||
LLaMA-2, BlueLM, InternLM, Mistral, Skywork, XVERSE, Yi choices: the same as LLaMA."}
|
||||
)
|
||||
additional_target: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
|
||||
)
|
||||
resume_lora_training: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
|
||||
)
|
||||
ppo_score_norm: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Use score normalization in PPO training."}
|
||||
)
|
||||
ppo_logger: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
|
||||
)
|
||||
ppo_target: Optional[float] = field(
|
||||
default=6.0,
|
||||
metadata={"help": "Target KL value for adaptive KL control in PPO training."}
|
||||
)
|
||||
dpo_beta: Optional[float] = field(
|
||||
default=0.1,
|
||||
metadata={"help": "The beta parameter for the DPO loss."}
|
||||
)
|
||||
dpo_ref_model: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the reference model used for the DPO training."}
|
||||
)
|
||||
dpo_ref_model_checkpoint: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory(s) containing the model checkpoints of the reference model."}
|
||||
)
|
||||
upcast_layernorm: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to upcast the layernorm weights in fp32."}
|
||||
)
|
||||
neft_alpha: Optional[float] = field(
|
||||
default=0,
|
||||
metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.lora_target, str): # support custom target modules/layers of LoRA
|
||||
self.lora_target = [target.strip() for target in self.lora_target.split(",")]
|
||||
|
||||
if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
|
||||
trainable_layer_ids = [self.num_hidden_layers - k - 1 for k in range(self.num_layer_trainable)]
|
||||
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
||||
trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
|
||||
if isinstance(self.additional_target, str):
|
||||
self.additional_target = [target.strip() for target in self.additional_target.split(",")]
|
||||
|
||||
self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
|
||||
|
||||
assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
|
||||
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
||||
|
||||
def save_to_json(self, json_path: str):
|
||||
"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
r"""Saves the content of this instance in JSON format inside `json_path`."""
|
||||
json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
|
||||
with open(json_path, "w", encoding="utf-8") as f:
|
||||
f.write(json_string)
|
||||
|
||||
@classmethod
|
||||
def load_from_json(cls, json_path: str):
|
||||
"""Creates an instance from the content of `json_path`."""
|
||||
r"""Creates an instance from the content of `json_path`."""
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
return cls(**json.loads(text))
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
from typing import Literal, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralArguments:
|
||||
"""
|
||||
Arguments pertaining to which stage we are going to perform.
|
||||
"""
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
|
||||
default="sft",
|
||||
metadata={"help": "Which stage will be performed in training."}
|
||||
)
|
||||
@@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass, field
|
||||
|
||||
@dataclass
|
||||
class GeneratingArguments:
|
||||
"""
|
||||
r"""
|
||||
Arguments pertaining to specify the decoding parameters.
|
||||
"""
|
||||
do_sample: Optional[bool] = field(
|
||||
@@ -28,7 +28,7 @@ class GeneratingArguments:
|
||||
metadata={"help": "Number of beams for beam search. 1 means no beam search."}
|
||||
)
|
||||
max_length: Optional[int] = field(
|
||||
default=None,
|
||||
default=512,
|
||||
metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
|
||||
)
|
||||
max_new_tokens: Optional[int] = field(
|
||||
@@ -46,6 +46,8 @@ class GeneratingArguments:
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
args = asdict(self)
|
||||
if args.get("max_new_tokens", None):
|
||||
if args.get("max_new_tokens", -1) > 0:
|
||||
args.pop("max_length", None)
|
||||
else:
|
||||
args.pop("max_new_tokens", None)
|
||||
return args
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import torch
|
||||
from typing import Literal, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
from dataclasses import asdict, dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
r"""
|
||||
Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
|
||||
"""
|
||||
model_name_or_path: str = field(
|
||||
@@ -16,21 +15,17 @@ class ModelArguments:
|
||||
metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
|
||||
)
|
||||
use_fast_tokenizer: Optional[bool] = field(
|
||||
default=False,
|
||||
default=True,
|
||||
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
|
||||
)
|
||||
use_auth_token: Optional[bool] = field(
|
||||
split_special_tokens: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
|
||||
metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}
|
||||
)
|
||||
model_revision: Optional[str] = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
|
||||
)
|
||||
padding_side: Optional[Literal["left", "right"]] = field(
|
||||
default="left",
|
||||
metadata={"help": "The side on which the model should have padding applied."}
|
||||
)
|
||||
quantization_bit: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "The number of bits to quantize the model."}
|
||||
@@ -43,30 +38,51 @@ class ModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether to use double quantization in int4 training or not."}
|
||||
)
|
||||
compute_dtype: Optional[torch.dtype] = field(
|
||||
rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
|
||||
default=None,
|
||||
metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
|
||||
metadata={"help": "Adopt scaled rotary positional embeddings."}
|
||||
)
|
||||
checkpoint_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
|
||||
metadata={"help": "Path to the directory(s) containing the model checkpoints as well as the configurations."}
|
||||
)
|
||||
flash_attn: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable FlashAttention-2 for faster training."}
|
||||
)
|
||||
shift_attn: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}
|
||||
)
|
||||
reward_model: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
|
||||
)
|
||||
resume_lora_training: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
|
||||
)
|
||||
plot_loss: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
|
||||
)
|
||||
hf_hub_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with Hugging Face Hub."}
|
||||
)
|
||||
export_dir: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the directory to save the exported model."}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.compute_dtype = None
|
||||
self.model_max_length = None
|
||||
|
||||
if self.split_special_tokens and self.use_fast_tokenizer:
|
||||
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
||||
|
||||
if self.checkpoint_dir is not None: # support merging multiple lora weights
|
||||
self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
|
||||
|
||||
if self.quantization_bit is not None:
|
||||
assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@@ -1,5 +1 @@
|
||||
from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
|
||||
from llmtuner.tuner.pt import run_pt
|
||||
from llmtuner.tuner.sft import run_sft
|
||||
from llmtuner.tuner.rm import run_rm
|
||||
from llmtuner.tuner.ppo import run_ppo
|
||||
from llmtuner.tuner.tune import export_model, run_exp
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from llmtuner.tuner.core.parser import get_train_args, get_infer_args
|
||||
from llmtuner.tuner.core.loader import load_model_and_tokenizer
|
||||
from llmtuner.tuner.core.utils import generate_model_card
|
||||
|
||||
@@ -2,16 +2,17 @@ import os
|
||||
import torch
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from transformers.utils import cached_file
|
||||
from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME
|
||||
from peft import (
|
||||
PeftModel,
|
||||
TaskType,
|
||||
LoraConfig,
|
||||
get_peft_model
|
||||
)
|
||||
from peft.utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.save_and_load import load_trainable_params
|
||||
from llmtuner.tuner.core.utils import find_all_linear_modules
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
@@ -25,8 +26,7 @@ def init_adapter(
|
||||
model: "PreTrainedModel",
|
||||
model_args: "ModelArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
is_trainable: bool,
|
||||
is_mergeable: bool
|
||||
is_trainable: bool
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Initializes the adapters.
|
||||
@@ -36,37 +36,35 @@ def init_adapter(
|
||||
Note that the trainable parameters must be cast to float32.
|
||||
"""
|
||||
|
||||
if finetuning_args.finetuning_type == "none" and is_trainable:
|
||||
raise ValueError("You cannot use finetuning_type=none while training.")
|
||||
if (not is_trainable) and model_args.checkpoint_dir is None:
|
||||
logger.info("Checkpoint is not found at evaluation, load the original model.")
|
||||
|
||||
if finetuning_args.finetuning_type == "full":
|
||||
if finetuning_args.finetuning_type == "full" and is_trainable:
|
||||
logger.info("Fine-tuning method: Full")
|
||||
model = model.float()
|
||||
|
||||
if finetuning_args.finetuning_type == "freeze":
|
||||
logger.info("Fine-tuning method: Freeze")
|
||||
num_layers = getattr(model.config, "num_layers")
|
||||
if finetuning_args.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
|
||||
trainable_layer_ids = [num_layers - k - 1 for k in range(finetuning_args.num_layer_trainable)]
|
||||
else: # fine-tuning the first n layers if num_layer_trainable < 0
|
||||
trainable_layer_ids = [k for k in range(-finetuning_args.num_layer_trainable)]
|
||||
|
||||
trainable_layers = ["{:d}.{}".format(idx, finetuning_args.name_module_trainable) for idx in trainable_layer_ids]
|
||||
for name, param in model.named_parameters():
|
||||
if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
|
||||
if not any(trainable_layer in name for trainable_layer in trainable_layers):
|
||||
param.requires_grad_(False)
|
||||
else:
|
||||
param.data = param.data.to(torch.float32)
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
|
||||
|
||||
if finetuning_args.finetuning_type == "lora":
|
||||
logger.info("Fine-tuning method: LoRA")
|
||||
latest_checkpoint = None
|
||||
checkpoint_to_resume = None
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
|
||||
"Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
|
||||
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
|
||||
"The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
|
||||
|
||||
if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
|
||||
checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
||||
if is_trainable and finetuning_args.resume_lora_training:
|
||||
checkpoints_to_merge, checkpoint_to_resume = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
|
||||
else:
|
||||
checkpoints_to_merge = model_args.checkpoint_dir
|
||||
|
||||
@@ -77,17 +75,23 @@ def init_adapter(
|
||||
if len(checkpoints_to_merge) > 0:
|
||||
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
|
||||
|
||||
if latest_checkpoint is not None: # resume lora training or quantized inference
|
||||
model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=is_trainable)
|
||||
if checkpoint_to_resume is not None: # resume lora training
|
||||
model = PeftModel.from_pretrained(model, checkpoint_to_resume, is_trainable=is_trainable)
|
||||
|
||||
if is_trainable and checkpoint_to_resume is None: # create new lora weights while training
|
||||
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
|
||||
target_modules = find_all_linear_modules(model, model_args.quantization_bit)
|
||||
else:
|
||||
target_modules = finetuning_args.lora_target
|
||||
|
||||
if is_trainable and latest_checkpoint is None: # create new lora weights while training
|
||||
lora_config = LoraConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
inference_mode=False,
|
||||
r=finetuning_args.lora_rank,
|
||||
lora_alpha=finetuning_args.lora_alpha,
|
||||
lora_dropout=finetuning_args.lora_dropout,
|
||||
target_modules=finetuning_args.lora_target
|
||||
target_modules=target_modules,
|
||||
modules_to_save=finetuning_args.additional_target
|
||||
)
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
@@ -95,3 +99,30 @@ def init_adapter(
|
||||
logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_valuehead_params(
|
||||
model: "PreTrainedModel",
|
||||
model_args: "ModelArguments"
|
||||
) -> bool:
|
||||
kwargs = {
|
||||
"path_or_repo_id": model_args.reward_model,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"token": model_args.hf_hub_token,
|
||||
"revision": model_args.model_revision
|
||||
}
|
||||
try:
|
||||
vhead_file = cached_file(filename=WEIGHTS_NAME, **kwargs)
|
||||
except:
|
||||
try:
|
||||
vhead_file = cached_file(filename=SAFE_WEIGHTS_NAME, **kwargs)
|
||||
except:
|
||||
logger.warning("Provided path ({}) does not contain valuehead weights.".format(model_args.reward_model))
|
||||
return False
|
||||
|
||||
vhead_params = torch.load(vhead_file, map_location="cpu")
|
||||
model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False)
|
||||
model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False)
|
||||
model.register_buffer("default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False)
|
||||
model.register_buffer("default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False)
|
||||
return True
|
||||
|
||||
@@ -1,38 +1,48 @@
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Tuple
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig
|
||||
BitsAndBytesConfig,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase
|
||||
)
|
||||
from transformers.utils import check_min_version
|
||||
from transformers.models.llama import modeling_llama as LlamaModule
|
||||
from transformers.utils.versions import require_version
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
|
||||
from transformers.tokenization_utils import PreTrainedTokenizerBase
|
||||
from peft import PeftModel
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
try:
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
except ImportError: # https://github.com/huggingface/transformers/releases/tag/v4.33.1
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from llmtuner.extras.logging import reset_logging, get_logger
|
||||
from llmtuner.extras.misc import count_parameters, prepare_model_for_training
|
||||
from llmtuner.extras.save_and_load import load_valuehead_params
|
||||
from llmtuner.extras.misc import count_parameters, infer_optim_dtype
|
||||
from llmtuner.extras.patches import llama_patch as LlamaPatches
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
from llmtuner.tuner.core.adapter import init_adapter
|
||||
from llmtuner.tuner.core.adapter import init_adapter, load_valuehead_params
|
||||
from llmtuner.tuner.core.utils import prepare_model_for_training
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
from llmtuner.hparams import ModelArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
check_min_version("4.29.1")
|
||||
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
|
||||
require_version("transformers>=4.31.0,<4.35.0", "To fix: pip install \"transformers>=4.31.0,<4.35.0\"")
|
||||
require_version("datasets>=2.14.0", "To fix: pip install datasets>=2.14.0")
|
||||
require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
|
||||
require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
|
||||
require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
|
||||
require_version("peft>=0.6.0", "To fix: pip install peft>=0.6.0")
|
||||
require_version("trl==0.7.2", "To fix: pip install trl==0.7.2")
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
@@ -40,49 +50,109 @@ def load_model_and_tokenizer(
|
||||
finetuning_args: "FinetuningArguments",
|
||||
is_trainable: Optional[bool] = False,
|
||||
stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
|
||||
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
|
||||
) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
|
||||
r"""
|
||||
Loads pretrained model and tokenizer.
|
||||
|
||||
Support both training and inference.
|
||||
"""
|
||||
if (not is_trainable) and model_args.checkpoint_dir is None:
|
||||
logger.warning("Checkpoint is not found at evaluation, load the original model.")
|
||||
finetuning_args = FinetuningArguments(finetuning_type="none")
|
||||
|
||||
assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
|
||||
"RM and PPO training can only be performed with the LoRA method."
|
||||
|
||||
config_kwargs = {
|
||||
"trust_remote_code": True,
|
||||
"cache_dir": model_args.cache_dir,
|
||||
"revision": model_args.model_revision,
|
||||
"use_auth_token": True if model_args.use_auth_token else None,
|
||||
"token": model_args.hf_hub_token
|
||||
}
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
use_fast=model_args.use_fast_tokenizer,
|
||||
padding_side=model_args.padding_side,
|
||||
split_special_tokens=model_args.split_special_tokens,
|
||||
padding_side="right", # training with left-padded tensors in fp16 precision may cause overflow
|
||||
**config_kwargs
|
||||
)
|
||||
if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
|
||||
tokenizer.pad_token_id = 0 # set as the <unk> token
|
||||
|
||||
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
||||
is_mergeable = True
|
||||
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
|
||||
model_to_load = model_args.checkpoint_dir[0]
|
||||
else:
|
||||
model_to_load = model_args.model_name_or_path
|
||||
|
||||
config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
|
||||
|
||||
# Fix tokenizer (for ChatGLM2 and ChatGLM3)
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer)
|
||||
|
||||
# Set model dtype
|
||||
if model_args.compute_dtype is not None: # for training
|
||||
setattr(config, "torch_dtype", model_args.compute_dtype)
|
||||
else: # for evaluation, priority: bf16 > fp16 > fp32
|
||||
model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None))
|
||||
|
||||
# Fix config (for Qwen)
|
||||
if getattr(config, "model_type", None) == "qwen":
|
||||
for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]:
|
||||
setattr(config, dtype_name, getattr(config, "torch_dtype", None) == dtype)
|
||||
|
||||
# Set RoPE scaling
|
||||
if model_args.rope_scaling is not None:
|
||||
if not hasattr(config, "rope_scaling"):
|
||||
logger.warning("Current model does not support RoPE scaling.")
|
||||
else:
|
||||
if is_trainable:
|
||||
if model_args.rope_scaling == "dynamic":
|
||||
logger.warning(
|
||||
"Dynamic NTK may not work well with fine-tuning. "
|
||||
"See: https://github.com/huggingface/transformers/pull/24653"
|
||||
)
|
||||
|
||||
current_max_length = getattr(config, "max_position_embeddings", None)
|
||||
if current_max_length and model_args.model_max_length > current_max_length:
|
||||
scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
|
||||
else:
|
||||
logger.warning("Input length is smaller than max length. Consider increase input length.")
|
||||
scaling_factor = 1.0
|
||||
else:
|
||||
scaling_factor = 2.0
|
||||
|
||||
setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
|
||||
logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
|
||||
model_args.rope_scaling, scaling_factor
|
||||
))
|
||||
|
||||
# Set FlashAttention-2
|
||||
if model_args.flash_attn:
|
||||
if getattr(config, "model_type", None) == "llama":
|
||||
LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2
|
||||
LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask
|
||||
logger.info("Using FlashAttention-2 for faster training and inference.")
|
||||
elif getattr(config, "model_type", None) in ["qwen", "Yi"]:
|
||||
logger.info("Current model automatically enables FlashAttention if installed.")
|
||||
else:
|
||||
logger.warning("Current model does not support FlashAttention-2.")
|
||||
elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama":
|
||||
LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention
|
||||
logger.warning("Using `--flash_attn` for faster training in large context length.")
|
||||
|
||||
# Set shift short attention (S^2-Attn)
|
||||
if is_trainable and model_args.shift_attn:
|
||||
if getattr(config, "model_type", None) == "llama":
|
||||
setattr(config, "group_size_ratio", 0.25)
|
||||
logger.info("Using shift short attention with group_size_ratio=1/4.")
|
||||
else:
|
||||
logger.warning("Current model does not support shift short attention.")
|
||||
|
||||
# Quantization configurations (using bitsandbytes library).
|
||||
if model_args.quantization_bit is not None:
|
||||
if is_deepspeed_zero3_enabled():
|
||||
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
|
||||
|
||||
if model_args.quantization_bit == 8:
|
||||
require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
|
||||
config_kwargs["load_in_8bit"] = True
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
load_in_8bit=True,
|
||||
llm_int8_threshold=6.0
|
||||
)
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
|
||||
elif model_args.quantization_bit == 4:
|
||||
if model_args.quantization_bit == 4:
|
||||
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
|
||||
config_kwargs["load_in_4bit"] = True
|
||||
config_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||
@@ -92,29 +162,27 @@ def load_model_and_tokenizer(
|
||||
bnb_4bit_quant_type=model_args.quantization_type
|
||||
)
|
||||
|
||||
is_mergeable = False
|
||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
|
||||
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
|
||||
|
||||
if (
|
||||
model_args.quantization_bit is not None
|
||||
or (os.environ.get('LOCAL_RANK') is not None and not is_deepspeed_zero3_enabled())
|
||||
):
|
||||
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
|
||||
|
||||
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
|
||||
model_to_load = model_args.checkpoint_dir[0]
|
||||
else:
|
||||
model_to_load = model_args.model_name_or_path
|
||||
|
||||
# Load and prepare pretrained models (without valuehead).
|
||||
# Load and prepare pre-trained models (without valuehead).
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_to_load,
|
||||
config=config,
|
||||
torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
|
||||
torch_dtype=model_args.compute_dtype,
|
||||
low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
|
||||
**config_kwargs
|
||||
)
|
||||
|
||||
# Disable custom generate method (for Qwen and Baichuan2)
|
||||
if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
|
||||
model.generate = MethodType(PreTrainedModel.generate, model)
|
||||
|
||||
# Fix LM head (for ChatGLM2 and ChatGLM3)
|
||||
if getattr(config, "model_type", None) == "chatglm":
|
||||
setattr(model, "lm_head", model.transformer.output_layer)
|
||||
setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"])
|
||||
|
||||
# Register auto class to save the custom code files.
|
||||
if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
|
||||
config.__class__.register_for_auto_class()
|
||||
@@ -124,35 +192,42 @@ def load_model_and_tokenizer(
|
||||
tokenizer.__class__.register_for_auto_class()
|
||||
|
||||
# Initialize adapters
|
||||
model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
|
||||
model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
|
||||
model = prepare_model_for_training(model=model, finetuning_args=finetuning_args) if is_trainable else model
|
||||
model = init_adapter(model, model_args, finetuning_args, is_trainable)
|
||||
model = model.train() if is_trainable else model.eval()
|
||||
|
||||
if stage == "rm" or stage == "ppo": # add value head
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
# Prepare model with valuehead for RLHF
|
||||
if stage == "rm" or stage == "ppo":
|
||||
model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
reset_logging()
|
||||
|
||||
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
|
||||
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
|
||||
if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
|
||||
logger.warning("Only the last checkpoint containing valuehead will be loaded.")
|
||||
if load_valuehead_params(model, model_args):
|
||||
model.v_head.load_state_dict({
|
||||
"summary.weight": getattr(model, "reward_head_weight"),
|
||||
"summary.bias": getattr(model, "reward_head_bias")
|
||||
})
|
||||
|
||||
if stage == "ppo": # load reward model
|
||||
assert is_trainable, "PPO stage cannot be performed at evaluation."
|
||||
assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
|
||||
logger.info("Load reward model from {}".format(model_args.reward_model))
|
||||
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
|
||||
assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
|
||||
if isinstance(model.pretrained_model, PeftModel):
|
||||
model.pretrained_model.load_adapter(model_args.reward_model, "reward")
|
||||
for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090
|
||||
if "default" in name:
|
||||
param.data = param.data.to(torch.float32) # trainable params should in fp32
|
||||
assert load_valuehead_params(model, model_args), "Reward model is not correctly loaded."
|
||||
|
||||
# Prepare model for inference
|
||||
if not is_trainable:
|
||||
model.requires_grad_(False) # fix all model params
|
||||
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
|
||||
model = model.to(model_args.compute_dtype) if model_args.quantization_bit is None else model
|
||||
|
||||
trainable_params, all_param = count_parameters(model)
|
||||
logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
|
||||
trainable_params, all_param, 100 * trainable_params / all_param
|
||||
))
|
||||
|
||||
if not is_trainable:
|
||||
logger.info("This IS expected that the trainable params is 0 if you are using model for inference only.")
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
@@ -5,21 +5,21 @@ import datasets
|
||||
import transformers
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.hparams import (
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments,
|
||||
GeneralArguments
|
||||
GeneratingArguments
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None):
|
||||
def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||
if args is not None:
|
||||
return parser.parse_dict(args)
|
||||
elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
@@ -32,26 +32,50 @@ def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None)
|
||||
|
||||
def parse_train_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
]:
|
||||
parser = HfArgumentParser((
|
||||
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
))
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def parse_infer_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
]:
|
||||
parser = HfArgumentParser((
|
||||
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
))
|
||||
return _parse_args(parser, args)
|
||||
|
||||
|
||||
def get_train_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
|
||||
model_args, data_args, training_args, finetuning_args, general_args = parse_train_args(args)
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
Seq2SeqTrainingArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
]:
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args)
|
||||
|
||||
# Setup logging
|
||||
if training_args.should_log:
|
||||
@@ -64,95 +88,137 @@ def get_train_args(
|
||||
transformers.utils.logging.enable_default_handler()
|
||||
transformers.utils.logging.enable_explicit_format()
|
||||
|
||||
# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
|
||||
data_args.init_for_training()
|
||||
# Check arguments
|
||||
data_args.init_for_training(training_args.seed)
|
||||
|
||||
assert general_args.stage == "sft" or (not training_args.predict_with_generate), \
|
||||
"`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
|
||||
if finetuning_args.stage != "pt" and data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
assert not (training_args.do_train and training_args.predict_with_generate), \
|
||||
"`predict_with_generate` cannot be set as True while training."
|
||||
if finetuning_args.stage != "sft" and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
|
||||
|
||||
assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
|
||||
"Please enable `predict_with_generate` to save model predictions."
|
||||
if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
|
||||
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
|
||||
|
||||
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
|
||||
"Quantization is only compatible with the LoRA method."
|
||||
|
||||
assert not (training_args.max_steps == -1 and data_args.streaming), \
|
||||
"Please specify `max_steps` in streaming mode."
|
||||
|
||||
assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
|
||||
"Streaming mode does not support evaluation currently."
|
||||
|
||||
assert not (general_args.stage == "ppo" and data_args.streaming), \
|
||||
"Streaming mode does not suppport PPO training currently."
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
if finetuning_args.stage in ["rm", "ppo"]:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
|
||||
else:
|
||||
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
|
||||
"Quantized model only accepts a single checkpoint."
|
||||
raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
|
||||
if training_args.load_best_model_at_end:
|
||||
raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.")
|
||||
|
||||
if model_args.quantization_bit is not None and (not training_args.do_train):
|
||||
if finetuning_args.stage == "ppo" and not training_args.do_train:
|
||||
raise ValueError("PPO training does not support evaluation.")
|
||||
|
||||
if finetuning_args.stage in ["rm", "dpo"]:
|
||||
for dataset_attr in data_args.dataset_list:
|
||||
if not dataset_attr.ranking:
|
||||
raise ValueError("Please use ranked datasets for reward modeling or DPO training.")
|
||||
|
||||
if finetuning_args.stage == "ppo" and model_args.reward_model is None:
|
||||
raise ValueError("Reward model is necessary for PPO training.")
|
||||
|
||||
if finetuning_args.stage == "ppo" and model_args.shift_attn:
|
||||
raise ValueError("PPO training is incompatible with S^2-Attn.")
|
||||
|
||||
if training_args.max_steps == -1 and data_args.streaming:
|
||||
raise ValueError("Please specify `max_steps` in streaming mode.")
|
||||
|
||||
if training_args.do_train and training_args.predict_with_generate:
|
||||
raise ValueError("`predict_with_generate` cannot be set as True while training.")
|
||||
|
||||
if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
|
||||
raise ValueError("Please specify `lora_target` in LoRA training.")
|
||||
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
if (
|
||||
model_args.checkpoint_dir is not None
|
||||
and len(model_args.checkpoint_dir) != 1
|
||||
and finetuning_args.finetuning_type != "lora"
|
||||
):
|
||||
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
|
||||
|
||||
if training_args.do_train and model_args.quantization_bit is not None and (not finetuning_args.upcast_layernorm):
|
||||
logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
|
||||
|
||||
if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
|
||||
logger.warning("We recommend enable mixed precision training.")
|
||||
|
||||
if (not training_args.do_train) and model_args.quantization_bit is not None:
|
||||
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
|
||||
|
||||
if training_args.do_train and (not training_args.fp16):
|
||||
logger.warning("We recommend enable fp16 mixed precision training.")
|
||||
|
||||
# postprocess training_args
|
||||
if (
|
||||
training_args.local_rank != -1
|
||||
and training_args.ddp_find_unused_parameters is None
|
||||
and finetuning_args.finetuning_type == "lora"
|
||||
):
|
||||
logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
|
||||
training_args.ddp_find_unused_parameters = False
|
||||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(ddp_find_unused_parameters=False))
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
|
||||
if data_args.max_samples is not None and data_args.streaming:
|
||||
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
|
||||
data_args.max_samples = None
|
||||
if (
|
||||
training_args.resume_from_checkpoint is None
|
||||
and training_args.do_train
|
||||
and os.path.isdir(training_args.output_dir)
|
||||
and not training_args.overwrite_output_dir
|
||||
):
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
||||
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
|
||||
|
||||
if data_args.dev_ratio > 1e-6 and data_args.streaming:
|
||||
logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
|
||||
data_args.dev_ratio = 0
|
||||
if last_checkpoint is not None:
|
||||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
logger.info(
|
||||
"Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
|
||||
)
|
||||
|
||||
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
if training_args.fp16:
|
||||
model_args.compute_dtype = torch.float16
|
||||
elif training_args.bf16:
|
||||
model_args.compute_dtype = torch.bfloat16
|
||||
else:
|
||||
model_args.compute_dtype = torch.float32
|
||||
# postprocess model_args
|
||||
model_args.compute_dtype = (
|
||||
torch.bfloat16 if training_args.bf16 else (torch.float16 if training_args.fp16 else None)
|
||||
)
|
||||
model_args.model_max_length = data_args.cutoff_len
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, 16-bits training: {}".format(
|
||||
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
|
||||
training_args.local_rank, training_args.device, training_args.n_gpu,
|
||||
bool(training_args.local_rank != -1), training_args.fp16
|
||||
bool(training_args.local_rank != -1), str(model_args.compute_dtype)
|
||||
))
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Set seed before initializing model.
|
||||
transformers.set_seed(training_args.seed)
|
||||
|
||||
return model_args, data_args, training_args, finetuning_args, general_args
|
||||
return model_args, data_args, training_args, finetuning_args, generating_args
|
||||
|
||||
|
||||
def get_infer_args(
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
|
||||
) -> Tuple[
|
||||
ModelArguments,
|
||||
DataArguments,
|
||||
FinetuningArguments,
|
||||
GeneratingArguments
|
||||
]:
|
||||
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
|
||||
|
||||
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
|
||||
"Quantization is only compatible with the LoRA method."
|
||||
if data_args.template is None:
|
||||
raise ValueError("Please specify which `template` to use.")
|
||||
|
||||
if model_args.checkpoint_dir is not None:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
|
||||
else:
|
||||
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
|
||||
"Quantized model only accepts a single checkpoint."
|
||||
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
if (
|
||||
model_args.checkpoint_dir is not None
|
||||
and len(model_args.checkpoint_dir) != 1
|
||||
and finetuning_args.finetuning_type != "lora"
|
||||
):
|
||||
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
|
||||
|
||||
return model_args, data_args, finetuning_args, generating_args
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Dict, Optional
|
||||
|
||||
from transformers import Seq2SeqTrainer
|
||||
from transformers.trainer import TRAINING_ARGS_NAME, WEIGHTS_NAME
|
||||
from transformers.modeling_utils import PreTrainedModel, unwrap_model
|
||||
from peft import PeftModel
|
||||
from trl import PreTrainedModelWrapper
|
||||
|
||||
from llmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PeftTrainer(Seq2SeqTrainer):
|
||||
r"""
|
||||
Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
self._remove_log()
|
||||
|
||||
def _remove_log(self):
|
||||
if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
|
||||
logger.warning("Previous log file in this folder will be deleted.")
|
||||
os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
|
||||
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
|
||||
r"""
|
||||
Saves trainable parameters as model checkpoint.
|
||||
|
||||
This function will only be executed at the process zero.
|
||||
|
||||
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
|
||||
"""
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
logger.info(f"Saving model checkpoint to {output_dir}")
|
||||
|
||||
model = unwrap_model(self.model)
|
||||
|
||||
if isinstance(model, PreTrainedModelWrapper):
|
||||
# Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
|
||||
model_state_dict = state_dict or model.state_dict()
|
||||
v_head_state_dict = {
|
||||
name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach()
|
||||
for name in model_state_dict.keys() if name.startswith("v_head.")
|
||||
}
|
||||
|
||||
torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
|
||||
model = model.pretrained_model
|
||||
|
||||
state_dict = state_dict or get_state_dict(model)
|
||||
if isinstance(model, (PeftModel, PreTrainedModel)):
|
||||
model.config.use_cache = True
|
||||
model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)
|
||||
model.config.use_cache = False
|
||||
else:
|
||||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
|
||||
|
||||
if self.finetuning_args.finetuning_type == "full" and self.tokenizer is not None:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
|
||||
f.write(self.args.to_json_string() + "\n")
|
||||
|
||||
self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
|
||||
|
||||
def _load_best_model(self):
|
||||
r"""
|
||||
Loads trainable parameters from model checkpoint.
|
||||
|
||||
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
|
||||
"""
|
||||
logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
|
||||
model = unwrap_model(self.model)
|
||||
|
||||
if isinstance(model, PreTrainedModelWrapper):
|
||||
model.v_head.load_state_dict(torch.load(
|
||||
os.path.join(self.state.best_model_checkpoint, VALUE_HEAD_FILE_NAME), map_location="cpu"
|
||||
))
|
||||
model = model.pretrained_model
|
||||
|
||||
if isinstance(model, PeftModel):
|
||||
model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
|
||||
else: # freeze/full-tuning
|
||||
load_trainable_params(model, self.state.best_model_checkpoint)
|
||||
107
src/llmtuner/tuner/core/utils.py
Normal file
107
src/llmtuner/tuner/core/utils.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def find_all_linear_modules(
|
||||
model: "PreTrainedModel",
|
||||
quantization_bit: Optional[int] = None
|
||||
) -> List[str]:
|
||||
if quantization_bit is not None:
|
||||
import bitsandbytes as bnb
|
||||
linear_cls = bnb.nn.Linear4bit if quantization_bit == 4 else bnb.nn.Linear8bitLt
|
||||
else:
|
||||
linear_cls = torch.nn.Linear
|
||||
|
||||
output_layer_names = ["lm_head"]
|
||||
if model.config.model_type == "chatglm":
|
||||
output_layer_names.append("output_layer")
|
||||
|
||||
module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if (
|
||||
isinstance(module, linear_cls)
|
||||
and not any([output_layer in name for output_layer in output_layer_names])
|
||||
):
|
||||
module_names.add(name.split(".")[-1])
|
||||
|
||||
logger.info("Found linear modules: {}".format(",".join(module_names)))
|
||||
return list(module_names)
|
||||
|
||||
|
||||
def generate_model_card(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
finetuning_args: "FinetuningArguments"
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"tasks": "text-generation",
|
||||
"finetuned_from": model_args.model_name_or_path,
|
||||
"dataset": [dataset.strip() for dataset in data_args.dataset.split(",")],
|
||||
"tags": ["llama-factory"] + (["lora"] if finetuning_args.finetuning_type == "lora" else [])
|
||||
}
|
||||
|
||||
|
||||
def prepare_model_for_training(
|
||||
model: "PreTrainedModel",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
output_layer_name: Optional[str] = "lm_head",
|
||||
use_gradient_checkpointing: Optional[bool] = True,
|
||||
layernorm_names: Optional[List[str]] = LAYERNORM_NAMES
|
||||
) -> "PreTrainedModel":
|
||||
r"""
|
||||
Includes:
|
||||
(1) cast the layernorm in fp32
|
||||
(2) make output embedding layer require grads
|
||||
(3) upcast the lm_head to fp32
|
||||
Inspired by: https://github.com/huggingface/peft/blob/v0.2.0/src/peft/utils/other.py#L33
|
||||
"""
|
||||
if finetuning_args.upcast_layernorm:
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(ln_name in name for ln_name in layernorm_names):
|
||||
param.data = param.data.to(torch.float32)
|
||||
logger.info("Upcasting weights in layernorm in float32.")
|
||||
|
||||
if finetuning_args.neft_alpha > 1e-6:
|
||||
def neftune_forward_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
||||
if module.training:
|
||||
dims = torch.tensor(output.size(1) * output.size(2))
|
||||
mag_norm = finetuning_args.neft_alpha / torch.sqrt(dims)
|
||||
output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
|
||||
return output
|
||||
|
||||
model.get_input_embeddings().register_forward_hook(neftune_forward_hook)
|
||||
logger.info("Using noisy embedding with alpha={:.2f}".format(finetuning_args.neft_alpha))
|
||||
|
||||
if use_gradient_checkpointing:
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
def make_inputs_require_grad(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
||||
output.requires_grad_(True)
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
|
||||
model.gradient_checkpointing_enable()
|
||||
model.config.use_cache = False # turn off when gradient checkpointing is enabled
|
||||
logger.info("Gradient checkpointing enabled.")
|
||||
|
||||
if finetuning_args.finetuning_type != "full" and hasattr(model, output_layer_name):
|
||||
output_layer = getattr(model, output_layer_name)
|
||||
if isinstance(output_layer, torch.nn.Linear):
|
||||
def fp32_forward_pre_hook(module: torch.nn.Module, args: Tuple[torch.Tensor]):
|
||||
return args[0].to(output_layer.weight.dtype)
|
||||
def fp32_forward_post_hook(module: torch.nn.Module, args: Tuple[torch.Tensor], output: torch.Tensor):
|
||||
return output.to(torch.float32)
|
||||
output_layer.register_forward_pre_hook(fp32_forward_pre_hook)
|
||||
output_layer.register_forward_hook(fp32_forward_post_hook)
|
||||
|
||||
return model
|
||||
1
src/llmtuner/tuner/dpo/__init__.py
Normal file
1
src/llmtuner/tuner/dpo/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from llmtuner.tuner.dpo.workflow import run_dpo
|
||||
51
src/llmtuner/tuner/dpo/collator.py
Normal file
51
src/llmtuner/tuner/dpo/collator.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Sequence, Tuple
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
|
||||
|
||||
@dataclass
|
||||
class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
"""
|
||||
|
||||
def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
|
||||
padded_labels = []
|
||||
for feature, (prompt_len, answer_len) in zip(batch, positions):
|
||||
if self.tokenizer.padding_side == "left":
|
||||
start, end = feature.size(0) - answer_len, feature.size(0)
|
||||
else:
|
||||
start, end = prompt_len, prompt_len + answer_len
|
||||
padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
|
||||
padded_tensor[start:end] = feature[start:end]
|
||||
padded_labels.append(padded_tensor)
|
||||
return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
|
||||
|
||||
def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Pads batched data to the longest sequence in the batch.
|
||||
|
||||
We generate 2 * n examples where the first n examples represent chosen examples and
|
||||
the last n examples represent rejected examples.
|
||||
"""
|
||||
concatenated_features = []
|
||||
label_positions = []
|
||||
for key in ("chosen_ids", "rejected_ids"):
|
||||
for feature in features:
|
||||
prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
|
||||
concatenated_features.append({
|
||||
"input_ids": feature["prompt_ids"] + feature[key],
|
||||
"attention_mask": [1] * (prompt_len + answer_len)
|
||||
})
|
||||
label_positions.append((prompt_len, answer_len))
|
||||
|
||||
batch = self.tokenizer.pad(
|
||||
concatenated_features,
|
||||
padding=self.padding,
|
||||
max_length=self.max_length,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=self.return_tensors,
|
||||
)
|
||||
batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
|
||||
return batch
|
||||
104
src/llmtuner/tuner/dpo/trainer.py
Normal file
104
src/llmtuner/tuner/dpo/trainer.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import torch
|
||||
import deepspeed # type: ignore
|
||||
from copy import deepcopy
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union
|
||||
from transformers import BatchEncoding, Trainer
|
||||
from trl import DPOTrainer
|
||||
from trl.trainer.utils import disable_dropout_in_model
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
from trl import PreTrainedModelWrapper
|
||||
|
||||
|
||||
class CustomDPOTrainer(DPOTrainer):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
beta: float,
|
||||
model: Union["PreTrainedModel", torch.nn.Module],
|
||||
ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
|
||||
disable_dropout: Optional[bool] = True,
|
||||
loss_type: Optional[Literal["sigmoid", "hinge"]] = "sigmoid",
|
||||
**kwargs
|
||||
):
|
||||
if disable_dropout:
|
||||
disable_dropout_in_model(model)
|
||||
if ref_model is not None:
|
||||
disable_dropout_in_model(ref_model)
|
||||
|
||||
self.is_encoder_decoder = model.config.is_encoder_decoder
|
||||
self.ref_model = ref_model
|
||||
self.use_dpo_data_collator = True # hack to avoid warning
|
||||
self.generate_during_eval = False # disable at evaluation
|
||||
self.label_pad_token_id = IGNORE_INDEX
|
||||
self.padding_value = 0
|
||||
self.beta = beta
|
||||
self.loss_type = loss_type
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
Trainer.__init__(self, model=model, **kwargs)
|
||||
if not hasattr(self, "accelerator"):
|
||||
raise AttributeError("Please update `transformers`.")
|
||||
|
||||
if ref_model is not None:
|
||||
if self.is_deepspeed_enabled:
|
||||
self.ref_model = self._prepare_deepspeed(self.ref_model)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
|
||||
def _prepare_deepspeed(self, model: "PreTrainedModelWrapper"):
|
||||
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
||||
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
||||
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
||||
if model is not None:
|
||||
if hasattr(model, "config"):
|
||||
hidden_size = (
|
||||
max(model.config.hidden_sizes)
|
||||
if getattr(model.config, "hidden_sizes", None)
|
||||
else getattr(model.config, "hidden_size", None)
|
||||
)
|
||||
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
||||
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
||||
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
||||
config_kwargs.update(
|
||||
{
|
||||
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
||||
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
||||
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
||||
}
|
||||
)
|
||||
|
||||
# If ZeRO-3 is used, we shard both the active and reference model.
|
||||
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
||||
if config_kwargs["zero_optimization"]["stage"] != 3:
|
||||
config_kwargs["zero_optimization"]["stage"] = 0
|
||||
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def concatenated_forward(
|
||||
self,
|
||||
model: Optional[torch.nn.Module] = None,
|
||||
batch: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
|
||||
|
||||
all_logits = model(
|
||||
input_ids=batch_copied["input_ids"],
|
||||
attention_mask=batch_copied["attention_mask"],
|
||||
return_dict=True
|
||||
).logits.to(torch.float32)
|
||||
|
||||
all_logps = self._get_batch_logps(
|
||||
all_logits,
|
||||
batch["labels"],
|
||||
average_log_prob=False
|
||||
)
|
||||
batch_size = batch["input_ids"].size(0) // 2
|
||||
chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
|
||||
chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits
|
||||
102
src/llmtuner/tuner/dpo/workflow.py
Normal file
102
src/llmtuner/tuner/dpo/workflow.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
|
||||
|
||||
from peft import PeftModel
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.hparams import ModelArguments
|
||||
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
|
||||
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
|
||||
from llmtuner.tuner.dpo.trainer import CustomDPOTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
from llmtuner.hparams import DataArguments, FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_dpo(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
|
||||
data_collator = DPODataCollatorWithPadding(
|
||||
tokenizer=tokenizer,
|
||||
pad_to_multiple_of=4,
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
# Create reference model
|
||||
if finetuning_args.dpo_ref_model is not None:
|
||||
ref_model_args_dict = model_args.to_dict()
|
||||
ref_model_args_dict.update(dict(
|
||||
model_name_or_path=finetuning_args.dpo_ref_model,
|
||||
checkpoint_dir=finetuning_args.dpo_ref_model_checkpoint
|
||||
))
|
||||
ref_model_args = ModelArguments(**ref_model_args_dict)
|
||||
ref_model, _ = load_model_and_tokenizer(ref_model_args, finetuning_args, is_trainable=False, stage="sft")
|
||||
logger.info("Created reference model from {}".format(finetuning_args.dpo_ref_model))
|
||||
elif training_args.do_train:
|
||||
if isinstance(model, PeftModel):
|
||||
ref_model = None
|
||||
else:
|
||||
ref_model, _ = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")
|
||||
logger.info("Created reference model from the model itself.")
|
||||
else:
|
||||
ref_model = model
|
||||
|
||||
# Update arguments
|
||||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = CustomDPOTrainer(
|
||||
beta=finetuning_args.dpo_beta,
|
||||
model=model,
|
||||
ref_model=ref_model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**split_dataset(dataset, data_args, training_args)
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
trainer.save_model()
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||
if id(model) == id(ref_model): # unable to compute rewards without a reference model
|
||||
logger.warning("Pass `dpo_ref_model` for computing rewards at evaluation.")
|
||||
remove_keys = [key for key in metrics.keys() if "rewards" in key]
|
||||
for key in remove_keys:
|
||||
metrics.pop(key)
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Create model card
|
||||
if training_args.do_train:
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
else:
|
||||
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
@@ -1,62 +1,78 @@
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
from transformers import TrainerState, TrainerControl
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
|
||||
from trl import PPOTrainer
|
||||
from trl.core import LengthSampler
|
||||
from trl.core import PPODecorators, logprobs_from_logits
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
|
||||
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
|
||||
from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.hparams import FinetuningArguments
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
from llmtuner.hparams import ModelArguments, FinetuningArguments, GeneratingArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
r"""
|
||||
Inherits PPOTrainer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_args: "ModelArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: List["LogCallback"],
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: List["TrainerCallback"],
|
||||
**kwargs
|
||||
):
|
||||
PPOTrainer.__init__(self, **kwargs)
|
||||
self.args = training_args
|
||||
self.model_args = model_args
|
||||
self.finetuning_args = finetuning_args
|
||||
self.log_callback = callbacks[0]
|
||||
self.generation_config = GenerationConfig(
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
|
||||
**generating_args.to_dict()
|
||||
)
|
||||
self.state = TrainerState()
|
||||
self.control = TrainerControl()
|
||||
self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
|
||||
self._remove_log()
|
||||
self.log_callback, self.save_callback = callbacks[0], callbacks[1]
|
||||
assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)
|
||||
if self.args.max_steps > 0:
|
||||
logger.info("max_steps is given, it will override any value given in num_train_epochs")
|
||||
|
||||
def ppo_train(self, max_target_length: int) -> None:
|
||||
def ppo_train(self) -> None:
|
||||
r"""
|
||||
Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
|
||||
"""
|
||||
total_train_batch_size = (
|
||||
self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
|
||||
)
|
||||
len_dataloader = len(self.dataloader)
|
||||
num_examples = len(self.dataset)
|
||||
num_train_epochs = self.args.num_train_epochs
|
||||
max_steps = math.ceil(num_train_epochs * len_dataloader)
|
||||
if self.args.max_steps > 0:
|
||||
num_examples = total_train_batch_size * self.args.max_steps
|
||||
num_train_epochs = sys.maxsize
|
||||
max_steps = self.args.max_steps
|
||||
steps_in_epoch = self.args.max_steps * self.args.gradient_accumulation_steps
|
||||
else:
|
||||
len_dataloader = len(self.dataloader)
|
||||
num_examples = len(self.dataset)
|
||||
num_train_epochs = self.args.num_train_epochs
|
||||
max_steps = math.ceil(num_train_epochs * len_dataloader)
|
||||
steps_in_epoch = len_dataloader
|
||||
|
||||
self.state.max_steps = max_steps
|
||||
self.state.num_train_epochs = num_train_epochs
|
||||
@@ -73,69 +89,59 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
logger.info(f" Total optimization steps = {max_steps}")
|
||||
logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")
|
||||
|
||||
# Keyword arguments for `model.generate`
|
||||
gen_kwargs = {
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": self.tokenizer.pad_token_id,
|
||||
"eos_token_id": self.tokenizer.eos_token_id,
|
||||
"logits_processor": get_logits_processor()
|
||||
}
|
||||
length_sampler = LengthSampler(max_target_length // 2, max_target_length)
|
||||
unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)
|
||||
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
dataiter = iter(self.dataloader)
|
||||
steps_trained = 0
|
||||
loss_meter = AverageMeter()
|
||||
reward_meter = AverageMeter()
|
||||
self.log_callback.on_train_begin(self.args, self.state, self.control)
|
||||
|
||||
for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False):
|
||||
batch = next(dataiter)
|
||||
steps_trained += 1
|
||||
for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
|
||||
try:
|
||||
batch = next(dataiter)
|
||||
except StopIteration:
|
||||
dataiter = iter(self.dataloader)
|
||||
batch = next(dataiter)
|
||||
|
||||
# Cast to inference mode
|
||||
unwrapped_model.gradient_checkpointing_disable()
|
||||
unwrapped_model.config.use_cache = True
|
||||
self.model.eval()
|
||||
|
||||
# Get responses
|
||||
query_tensors = batch["input_ids"]
|
||||
response_tensors = self.generate(batch, length_sampler, return_prompt=False, **gen_kwargs)
|
||||
# Get inputs
|
||||
queries, responses = self.get_inputs(batch)
|
||||
self.tokenizer.padding_side = "right" # change padding side
|
||||
rewards = self.get_rewards(queries, responses, unwrapped_model)
|
||||
|
||||
queries, responses = [], []
|
||||
for i in range(len(query_tensors)):
|
||||
query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
|
||||
response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
|
||||
queries.append(query_tensors[i, query_length:]) # remove padding from left
|
||||
responses.append(response_tensors[i, :response_length]) # remove padding from right
|
||||
|
||||
# Compute rewards
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
with torch.no_grad():
|
||||
_, _, values = self.model(
|
||||
**self.prepare_model_inputs(queries, responses),
|
||||
output_hidden_states=True,
|
||||
return_dict=True
|
||||
)
|
||||
rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
|
||||
replace_model(unwrapped_model, target="default")
|
||||
|
||||
# Run PPO step
|
||||
# Cast to training mode
|
||||
unwrapped_model.gradient_checkpointing_enable()
|
||||
unwrapped_model.config.use_cache = False
|
||||
stats = self.step(queries, responses, rewards)
|
||||
self.model.train()
|
||||
|
||||
loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
|
||||
# Run PPO step
|
||||
stats = self.step(queries, responses, rewards)
|
||||
self.tokenizer.padding_side = "left" # restore padding side
|
||||
loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
|
||||
reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
|
||||
|
||||
if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
|
||||
if self.config.log_with is not None:
|
||||
try:
|
||||
batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
|
||||
batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
|
||||
self.log_stats(stats, batch, rewards)
|
||||
except:
|
||||
logger.warning("Failed to save stats due to unknown errors.")
|
||||
|
||||
self.state.global_step += 1
|
||||
self.log_callback.on_step_end(self.args, self.state, self.control)
|
||||
|
||||
if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
|
||||
logs = dict(
|
||||
loss=round(loss_meter.avg, 4),
|
||||
reward=round(reward_meter.avg, 4),
|
||||
learning_rate=stats["ppo/learning_rate"],
|
||||
epoch=round(step / len_dataloader, 2)
|
||||
epoch=round(step / steps_in_epoch, 2)
|
||||
)
|
||||
print(logs)
|
||||
tqdm.write(str(logs))
|
||||
logs["step"] = step
|
||||
self.state.log_history.append(logs)
|
||||
self.log_callback.on_log(self.args, self.state, self.control)
|
||||
@@ -143,47 +149,155 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
|
||||
reward_meter.reset()
|
||||
|
||||
if (step+1) % self.args.save_steps == 0: # save checkpoint
|
||||
self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))
|
||||
self.save_model(os.path.join(
|
||||
self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)
|
||||
))
|
||||
self.save_callback.on_save(
|
||||
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
|
||||
)
|
||||
|
||||
if self.control.should_epoch_stop or self.control.should_training_stop:
|
||||
break
|
||||
|
||||
if steps_trained == len_dataloader:
|
||||
dataiter = iter(self.dataloader)
|
||||
steps_trained = 0
|
||||
self.log_callback.on_train_end(self.args, self.state, self.control)
|
||||
self.save_callback.on_train_end(
|
||||
self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
inputs: Dict[str, torch.Tensor],
|
||||
length_sampler: Optional[Callable] = None,
|
||||
return_prompt: Optional[bool] = True,
|
||||
**generation_kwargs
|
||||
) -> torch.Tensor:
|
||||
def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
|
||||
r"""
|
||||
Generates model's responses given queries.
|
||||
"""
|
||||
if self.finetuning_args.upcast_layernorm:
|
||||
layernorm_params = dump_layernorm(self.model)
|
||||
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
|
||||
response: torch.Tensor = unwrapped_model.generate(
|
||||
generation_config=self.generation_config,
|
||||
logits_processor=get_logits_processor(),
|
||||
**batch
|
||||
)
|
||||
|
||||
if self.finetuning_args.upcast_layernorm:
|
||||
restore_layernorm(self.model, layernorm_params)
|
||||
|
||||
query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
|
||||
queries, responses = [], []
|
||||
for i in range(len(query)):
|
||||
query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0].item()
|
||||
response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
|
||||
|
||||
if len(response_index) == 0:
|
||||
response_length = 1 # allow empty response
|
||||
else:
|
||||
response_length = response_index[-1].item() + 1
|
||||
|
||||
queries.append(query[i, query_length:]) # remove padding from left
|
||||
responses.append(response[i, :response_length]) # remove padding from right
|
||||
|
||||
return queries, responses
|
||||
|
||||
@torch.no_grad()
|
||||
def get_rewards(
|
||||
self,
|
||||
queries: List[torch.Tensor],
|
||||
responses: List[torch.Tensor],
|
||||
unwrapped_model: "AutoModelForCausalLMWithValueHead"
|
||||
) -> List[torch.Tensor]:
|
||||
r"""
|
||||
Computes scores using given reward model.
|
||||
"""
|
||||
replace_model(unwrapped_model, target="reward")
|
||||
batch = self.prepare_model_inputs(queries, responses)
|
||||
|
||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||
_, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
|
||||
|
||||
if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
||||
rewards = []
|
||||
for i in range(values.size(0)):
|
||||
end_indexes = (batch["input_ids"][i] != self.tokenizer.eos_token_id).nonzero()
|
||||
end_index = end_indexes[-1].item() if len(end_indexes) else 0
|
||||
rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type
|
||||
|
||||
replace_model(unwrapped_model, target="default")
|
||||
return rewards
|
||||
|
||||
@PPODecorators.empty_cuda_cache()
|
||||
def batched_forward_pass(
|
||||
self,
|
||||
model: "AutoModelForCausalLMWithValueHead",
|
||||
queries: torch.Tensor,
|
||||
responses: torch.Tensor,
|
||||
model_inputs: dict,
|
||||
return_logits: Optional[bool] = False,
|
||||
response_masks: Optional[torch.Tensor] = None
|
||||
):
|
||||
r"""
|
||||
Calculates model outputs in multiple batches.
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
"""
|
||||
self.model, layer_norm_params = cast_layernorm_dtype(self.model)
|
||||
bs = len(queries)
|
||||
fbs = self.config.mini_batch_size
|
||||
all_logprobs = []
|
||||
all_logits = []
|
||||
all_masks = []
|
||||
all_values = []
|
||||
|
||||
if length_sampler is not None:
|
||||
generation_kwargs["max_new_tokens"] = length_sampler()
|
||||
for i in range(math.ceil(bs / fbs)):
|
||||
input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
|
||||
query_batch = queries[i * fbs : (i + 1) * fbs]
|
||||
response_batch = responses[i * fbs : (i + 1) * fbs]
|
||||
if response_masks is not None:
|
||||
response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
|
||||
input_ids = input_kwargs["input_ids"]
|
||||
attention_mask = input_kwargs["attention_mask"]
|
||||
|
||||
unwrapped_model = self.accelerator.unwrap_model(self.model)
|
||||
with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16
|
||||
logits, _, values = model(**input_kwargs)
|
||||
|
||||
response = unwrapped_model.generate(**inputs, **generation_kwargs)
|
||||
if values.size(0) != input_ids.size(0): # adapt to chatglm2
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
||||
# Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
|
||||
if unwrapped_model.pretrained_model.generation_config._from_model_config:
|
||||
unwrapped_model.pretrained_model.generation_config._from_model_config = False
|
||||
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
|
||||
masks = torch.zeros_like(attention_mask)
|
||||
masks[:, :-1] = attention_mask[:, 1:]
|
||||
|
||||
self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
|
||||
for j in range(len(query_batch)):
|
||||
start = len(query_batch[j]) - 1
|
||||
if attention_mask[j, 0] == 0: # offset left padding
|
||||
start += attention_mask[j, :].nonzero()[0].item()
|
||||
end = start + len(response_batch[j])
|
||||
|
||||
if not return_prompt and not self.is_encoder_decoder:
|
||||
return response[:, inputs["input_ids"].size(1):]
|
||||
return response
|
||||
if response_masks is not None:
|
||||
response_masks_batch = torch.cat(
|
||||
(torch.zeros_like(query_batch[j]), response_masks_batch[j])
|
||||
)[1:]
|
||||
|
||||
masks[j, :start] = 0
|
||||
masks[j, end:] = 0
|
||||
if response_masks is not None:
|
||||
masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
|
||||
|
||||
if return_logits:
|
||||
all_logits.append(logits)
|
||||
else:
|
||||
del logits
|
||||
|
||||
all_values.append(values)
|
||||
all_logprobs.append(logprobs)
|
||||
all_masks.append(masks)
|
||||
|
||||
return (
|
||||
torch.cat(all_logprobs),
|
||||
torch.cat(all_logits)[:, :-1] if return_logits else None,
|
||||
torch.cat(all_values)[:, :-1],
|
||||
torch.cat(all_masks)[:, :-1],
|
||||
)
|
||||
|
||||
def save_model(self, output_dir: Optional[str] = None) -> None:
|
||||
r"""
|
||||
|
||||
@@ -1,39 +1,35 @@
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
from llmtuner.extras.constants import LAYERNORM_NAMES
|
||||
from typing import TYPE_CHECKING, Dict, Literal, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
|
||||
def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["default", "reward"]) -> None:
|
||||
if target == "reward": # save default head temporarily
|
||||
valuehead_state_dict = model.v_head.state_dict()
|
||||
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
|
||||
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])
|
||||
valuehead_state_dict: Dict[str, torch.Tensor] = model.v_head.state_dict()
|
||||
setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"].detach().clone())
|
||||
setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"].detach().clone())
|
||||
|
||||
model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
|
||||
model.v_head.load_state_dict({
|
||||
"summary.weight": getattr(model, "{}_head_weight".format(target)),
|
||||
"summary.bias": getattr(model, "{}_head_bias".format(target))
|
||||
"summary.weight": model.get_buffer("{}_head_weight".format(target)).detach().clone(),
|
||||
"summary.bias": model.get_buffer("{}_head_bias".format(target)).detach().clone()
|
||||
})
|
||||
|
||||
|
||||
def cast_layernorm_dtype(
|
||||
model: "AutoModelForCausalLMWithValueHead",
|
||||
layer_norm_names: List[str] = LAYERNORM_NAMES,
|
||||
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
|
||||
) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]:
|
||||
|
||||
layer_norm_state_dict = {}
|
||||
|
||||
def dump_layernorm(model: "PreTrainedModel") -> Dict[str, torch.Tensor]:
|
||||
layer_norm_params = {}
|
||||
for name, param in model.named_parameters():
|
||||
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
|
||||
if layer_norm_params is not None:
|
||||
param.data = layer_norm_params[name] # restore float32 weights
|
||||
else:
|
||||
layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
|
||||
param.data = param.data.to(torch.float16)
|
||||
if param.data.dtype == torch.float32:
|
||||
layer_norm_params[name] = param.data.detach().clone()
|
||||
param.data = param.data.to(model.config.torch_dtype)
|
||||
|
||||
return model, layer_norm_state_dict
|
||||
return layer_norm_params
|
||||
|
||||
|
||||
def restore_layernorm(model: "PreTrainedModel", layernorm_params: Optional[Dict[str, torch.Tensor]] = None) -> None:
|
||||
for name, param in model.named_parameters():
|
||||
if name in layernorm_params:
|
||||
param.data = layernorm_params[name]
|
||||
|
||||
@@ -1,23 +1,21 @@
|
||||
# Inspired by:
|
||||
# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
|
||||
# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING
|
||||
from trl import PPOConfig
|
||||
from torch.optim import AdamW
|
||||
from typing import Optional, List
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import DataCollatorWithPadding
|
||||
from transformers.optimization import get_scheduler
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.callbacks import SavePeftModelCallback
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
|
||||
from llmtuner.tuner.ppo.trainer import CustomPPOTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
|
||||
|
||||
def run_ppo(
|
||||
@@ -25,12 +23,15 @@ def run_ppo(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
|
||||
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=tokenizer.pad_token_id)
|
||||
|
||||
tokenizer.padding_side = "left" # use left-padding in generation while using right-padding in training
|
||||
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
|
||||
|
||||
ppo_config = PPOConfig(
|
||||
model_name=model_args.model_name_or_path,
|
||||
@@ -39,24 +40,39 @@ def run_ppo(
|
||||
batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
|
||||
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
|
||||
ppo_epochs=1,
|
||||
max_grad_norm=training_args.max_grad_norm
|
||||
max_grad_norm=training_args.max_grad_norm,
|
||||
seed=training_args.seed,
|
||||
optimize_cuda_cache=True,
|
||||
target=finetuning_args.ppo_target,
|
||||
log_with=finetuning_args.ppo_logger,
|
||||
use_score_scaling=finetuning_args.ppo_score_norm,
|
||||
use_score_norm=finetuning_args.ppo_score_norm,
|
||||
accelerator_kwargs={"step_scheduler_with_optimizer": False}
|
||||
)
|
||||
|
||||
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
|
||||
total_train_batch_size = \
|
||||
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
|
||||
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
|
||||
if training_args.max_steps > 0:
|
||||
num_training_steps = training_args.max_steps
|
||||
else:
|
||||
total_train_batch_size = (
|
||||
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
|
||||
)
|
||||
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
training_args.lr_scheduler_type,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=training_args.warmup_steps,
|
||||
num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
|
||||
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps
|
||||
)
|
||||
|
||||
# Initialize our Trainer
|
||||
ppo_trainer = PPOPeftTrainer(
|
||||
ppo_trainer = CustomPPOTrainer(
|
||||
model_args=model_args,
|
||||
training_args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
callbacks=callbacks,
|
||||
generating_args=generating_args,
|
||||
callbacks=callbacks + [SavePeftModelCallback()],
|
||||
config=ppo_config,
|
||||
model=model,
|
||||
ref_model=None,
|
||||
@@ -67,8 +83,10 @@ def run_ppo(
|
||||
lr_scheduler=lr_scheduler
|
||||
)
|
||||
|
||||
ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
|
||||
ppo_trainer.save_model()
|
||||
ppo_trainer.save_state() # must be after save_model
|
||||
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "reward"])
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
ppo_trainer.ppo_train()
|
||||
ppo_trainer.save_model()
|
||||
ppo_trainer.save_state() # must be called after save_model to have a folder
|
||||
if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "reward"])
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/language-modeling/run_clm.py
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
from transformers import DataCollatorForLanguageModeling, Trainer
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
@@ -21,34 +18,30 @@ def run_pt(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer=tokenizer,
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
)
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = PeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
|
||||
**split_dataset(dataset, data_args, training_args)
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train()
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
trainer.save_model()
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
trainer.save_model()
|
||||
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
|
||||
@@ -61,6 +54,12 @@ def run_pt(
|
||||
perplexity = float("inf")
|
||||
|
||||
metrics["perplexity"] = perplexity
|
||||
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Create model card
|
||||
if training_args.do_train:
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
else:
|
||||
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, Sequence
|
||||
from transformers import DataCollatorWithPadding
|
||||
|
||||
|
||||
@dataclass
|
||||
class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
|
||||
r"""
|
||||
Data collator for pairwise data.
|
||||
@@ -16,7 +18,10 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
|
||||
the last n examples represent rejected examples.
|
||||
"""
|
||||
features = [
|
||||
{"input_ids": feature[key], "attention_mask": [1] * len(feature[key])}
|
||||
for key in ("accept_ids", "reject_ids") for feature in features
|
||||
{
|
||||
"input_ids": feature["prompt_ids"] + feature[key],
|
||||
"attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
|
||||
}
|
||||
for key in ("chosen_ids", "rejected_ids") for feature in features
|
||||
]
|
||||
return super().__call__(features)
|
||||
|
||||
@@ -2,9 +2,9 @@ import os
|
||||
import json
|
||||
import torch
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from transformers import Trainer
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.trainer import PredictionOutput
|
||||
@@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PairwisePeftTrainer(PeftTrainer):
|
||||
class PairwiseTrainer(Trainer):
|
||||
r"""
|
||||
Inherits PeftTrainer to compute pairwise loss.
|
||||
"""
|
||||
@@ -32,19 +32,51 @@ class PairwisePeftTrainer(PeftTrainer):
|
||||
r"""
|
||||
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
|
||||
|
||||
We use score on the EOS token to represent reward of the whole sentence.
|
||||
|
||||
Subclass and override to inject custom behavior. It should not be directly used by external scripts.
|
||||
Subclass and override to inject custom behavior.
|
||||
|
||||
Note that the first element will be removed from the output tuple.
|
||||
|
||||
See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
|
||||
"""
|
||||
batch_size = inputs["input_ids"].size(0) // 2
|
||||
# Compute rewards
|
||||
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
|
||||
r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
|
||||
loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
|
||||
return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
|
||||
if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
||||
# Split the inputs and rewards into two parts, chosen and rejected
|
||||
batch_size = inputs["input_ids"].size(0) // 2
|
||||
chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
|
||||
chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
|
||||
chosen_scores, rejected_scores = [], []
|
||||
|
||||
# Compute pairwise loss. Only backprop on the different tokens before padding
|
||||
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
|
||||
loss = 0
|
||||
for i in range(batch_size):
|
||||
chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
|
||||
rejected_length = (rejected_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
|
||||
check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
|
||||
|
||||
if len(check_divergence) == 0:
|
||||
end_index = chosen_length
|
||||
div_index = end_index - 1
|
||||
else:
|
||||
end_index = max(chosen_length, rejected_length)
|
||||
div_index = check_divergence[0]
|
||||
|
||||
assert div_index > 0
|
||||
chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
|
||||
rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
|
||||
if return_outputs: # use the score on the last token except pad token for inference
|
||||
chosen_scores.append(chosen_rewards[i, chosen_length-1])
|
||||
rejected_scores.append(rejected_rewards[i, rejected_length-1])
|
||||
loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
|
||||
|
||||
loss = loss / batch_size
|
||||
if return_outputs:
|
||||
chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
|
||||
return loss, [loss, chosen_scores, rejected_scores]
|
||||
|
||||
return loss
|
||||
|
||||
def save_predictions(
|
||||
self,
|
||||
@@ -60,11 +92,10 @@ class PairwisePeftTrainer(PeftTrainer):
|
||||
|
||||
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
|
||||
logger.info(f"Saving prediction results to {output_prediction_file}")
|
||||
|
||||
acc_scores, rej_scores = predict_results.predictions
|
||||
chosen_scores, rejected_scores = predict_results.predictions
|
||||
|
||||
with open(output_prediction_file, "w", encoding="utf-8") as writer:
|
||||
res: List[str] = []
|
||||
for acc_score, rej_score in zip(acc_scores, rej_scores):
|
||||
res.append(json.dumps({"accept": round(float(acc_score), 2), "reject": round(float(rej_score), 2)}))
|
||||
for c_score, r_score in zip(chosen_scores, rejected_scores):
|
||||
res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
|
||||
writer.write("\n".join(res))
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
# Inspired by:
|
||||
# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
|
||||
# https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
|
||||
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import Seq2SeqTrainingArguments
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.callbacks import SavePeftModelCallback
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
|
||||
from llmtuner.tuner.rm.metric import compute_accuracy
|
||||
from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding
|
||||
from llmtuner.tuner.rm.trainer import PairwisePeftTrainer
|
||||
from llmtuner.tuner.rm.trainer import PairwiseTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from transformers import TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
|
||||
|
||||
@@ -22,34 +21,36 @@ def run_rm(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
|
||||
data_collator = PairwiseDataCollatorWithPadding(tokenizer)
|
||||
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=4)
|
||||
|
||||
training_args.remove_unused_columns = False # important for pairwise dataset
|
||||
# Update arguments
|
||||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(remove_unused_columns=False)) # important for pairwise dataset
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = PairwisePeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
trainer = PairwiseTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
callbacks=callbacks + [SavePeftModelCallback()],
|
||||
compute_metrics=compute_accuracy,
|
||||
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
|
||||
**split_dataset(dataset, data_args, training_args)
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train()
|
||||
trainer.save_model()
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
trainer.save_model()
|
||||
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
|
||||
@@ -65,3 +66,10 @@ def run_rm(
|
||||
trainer.log_metrics("predict", predict_results.metrics)
|
||||
trainer.save_metrics("predict", predict_results.metrics)
|
||||
trainer.save_predictions(predict_results)
|
||||
|
||||
# Create model card
|
||||
if training_args.do_train:
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
else:
|
||||
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
|
||||
@@ -25,7 +25,7 @@ class ComputeMetrics:
|
||||
Uses the model predictions to compute metrics.
|
||||
"""
|
||||
preds, labels = eval_preds
|
||||
score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||
score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
|
||||
|
||||
preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
|
||||
labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
|
||||
@@ -49,6 +49,5 @@ class ComputeMetrics:
|
||||
|
||||
bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
|
||||
score_dict["bleu-4"].append(round(bleu_score * 100, 4))
|
||||
score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label))
|
||||
|
||||
return {k: float(np.mean(v)) for k, v in score_dict.items()}
|
||||
|
||||
@@ -4,10 +4,10 @@ import torch
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
from transformers import Seq2SeqTrainer
|
||||
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core.trainer import PeftTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers.trainer import PredictionOutput
|
||||
@@ -16,7 +16,7 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class Seq2SeqPeftTrainer(PeftTrainer):
|
||||
class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
r"""
|
||||
Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
|
||||
"""
|
||||
@@ -33,53 +33,36 @@ class Seq2SeqPeftTrainer(PeftTrainer):
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
"""
|
||||
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
|
||||
if prompt_len > label_len:
|
||||
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
|
||||
if label_len > prompt_len:
|
||||
inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"])
|
||||
if "attention_mask" in inputs:
|
||||
inputs["attention_mask"] = self._pad_tensors_to_target_len(
|
||||
inputs["attention_mask"], inputs["labels"], pad_token_id=0
|
||||
)
|
||||
if "position_ids" in inputs:
|
||||
inputs["position_ids"] = self._pad_tensors_to_target_len(
|
||||
inputs["position_ids"], inputs["labels"], pad_token_id=0
|
||||
)
|
||||
labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels
|
||||
if self.args.predict_with_generate:
|
||||
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
|
||||
prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
|
||||
if prompt_len > label_len:
|
||||
inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
|
||||
if label_len > prompt_len:
|
||||
inputs["labels"] = inputs["labels"][:, :prompt_len] # truncate the labels instead of padding the inputs
|
||||
|
||||
loss, generated_tokens, labels = super().prediction_step(
|
||||
loss, generated_tokens, _ = super().prediction_step(
|
||||
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
|
||||
)
|
||||
generated_tokens = (
|
||||
generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
|
||||
)
|
||||
if generated_tokens is not None and self.args.predict_with_generate:
|
||||
generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id
|
||||
generated_tokens = generated_tokens.contiguous()
|
||||
|
||||
return (loss, generated_tokens, labels)
|
||||
return loss, generated_tokens, labels
|
||||
|
||||
def _pad_tensors_to_target_len(
|
||||
self,
|
||||
src_tensor: torch.Tensor,
|
||||
tgt_tensor: torch.Tensor,
|
||||
pad_token_id: Optional[int] = None
|
||||
tgt_tensor: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Pads the tensor to the same length as the target tensor.
|
||||
|
||||
Should only be called when predict_with_generate=True.
|
||||
"""
|
||||
if pad_token_id is None:
|
||||
if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
|
||||
assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
else:
|
||||
if self.model.config.pad_token_id is not None:
|
||||
pad_token_id = self.model.config.pad_token_id
|
||||
else:
|
||||
raise ValueError("Pad_token_id must be set in the configuration of the model.")
|
||||
|
||||
padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
|
||||
assert self.tokenizer.pad_token_id is not None, "Pad token is required."
|
||||
padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor)
|
||||
padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
|
||||
return padded_tensor
|
||||
return padded_tensor.contiguous() # in contiguous memory
|
||||
|
||||
def save_predictions(
|
||||
self,
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
|
||||
# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, List
|
||||
from transformers import DataCollatorForSeq2Seq
|
||||
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
|
||||
|
||||
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import IGNORE_INDEX
|
||||
from llmtuner.extras.misc import get_logits_processor
|
||||
from llmtuner.extras.ploting import plot_loss
|
||||
from llmtuner.tuner.core import load_model_and_tokenizer
|
||||
from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer
|
||||
from llmtuner.tuner.sft.metric import ComputeMetrics
|
||||
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
|
||||
from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
|
||||
from transformers import TrainerCallback
|
||||
from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
|
||||
|
||||
def run_sft(
|
||||
@@ -22,50 +21,54 @@ def run_sft(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
|
||||
generating_args: "GeneratingArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None
|
||||
):
|
||||
dataset = get_dataset(model_args, data_args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
|
||||
dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
|
||||
|
||||
if training_args.predict_with_generate:
|
||||
tokenizer.padding_side = "left" # use left-padding in generation
|
||||
|
||||
data_collator = DataCollatorForSeq2Seq(
|
||||
tokenizer=tokenizer,
|
||||
pad_to_multiple_of=4 if tokenizer.padding_side == "right" else None, # for shift short attention
|
||||
label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
|
||||
)
|
||||
|
||||
# Override the decoding parameters of Seq2SeqTrainer
|
||||
training_args.generation_max_length = training_args.generation_max_length if \
|
||||
training_args.generation_max_length is not None else data_args.max_target_length
|
||||
training_args.generation_num_beams = data_args.eval_num_beams if \
|
||||
data_args.eval_num_beams is not None else training_args.generation_num_beams
|
||||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(
|
||||
generation_max_length=training_args.generation_max_length or data_args.cutoff_len,
|
||||
generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams
|
||||
))
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = Seq2SeqPeftTrainer(
|
||||
finetuning_args=finetuning_args,
|
||||
trainer = CustomSeq2SeqTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks,
|
||||
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
|
||||
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
|
||||
**split_dataset(dataset, data_args, training_args)
|
||||
)
|
||||
|
||||
# Keyword arguments for `model.generate`
|
||||
gen_kwargs = {
|
||||
"do_sample": True,
|
||||
"top_p": 0.7,
|
||||
"max_new_tokens": data_args.max_target_length + 1,
|
||||
"temperature": 0.95,
|
||||
"logits_processor": get_logits_processor()
|
||||
}
|
||||
gen_kwargs = generating_args.to_dict()
|
||||
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
|
||||
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
|
||||
gen_kwargs["logits_processor"] = get_logits_processor()
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train()
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
trainer.save_model()
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
trainer.save_model()
|
||||
if trainer.is_world_process_zero() and model_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
|
||||
|
||||
@@ -85,3 +88,10 @@ def run_sft(
|
||||
trainer.log_metrics("predict", predict_results.metrics)
|
||||
trainer.save_metrics("predict", predict_results.metrics)
|
||||
trainer.save_predictions(predict_results)
|
||||
|
||||
# Create model card
|
||||
if training_args.do_train:
|
||||
if training_args.push_to_hub:
|
||||
trainer.push_to_hub(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
else:
|
||||
trainer.create_model_card(**generate_model_card(model_args, data_args, finetuning_args))
|
||||
|
||||
51
src/llmtuner/tuner/tune.py
Normal file
51
src/llmtuner/tuner/tune.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.logging import get_logger
|
||||
from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
|
||||
from llmtuner.tuner.pt import run_pt
|
||||
from llmtuner.tuner.sft import run_sft
|
||||
from llmtuner.tuner.rm import run_rm
|
||||
from llmtuner.tuner.ppo import run_ppo
|
||||
from llmtuner.tuner.dpo import run_dpo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import TrainerCallback
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
|
||||
model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args)
|
||||
callbacks = [LogCallback()] if callbacks is None else callbacks
|
||||
|
||||
if finetuning_args.stage == "pt":
|
||||
run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "sft":
|
||||
run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
elif finetuning_args.stage == "rm":
|
||||
run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
elif finetuning_args.stage == "ppo":
|
||||
run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
|
||||
elif finetuning_args.stage == "dpo":
|
||||
run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
|
||||
else:
|
||||
raise ValueError("Unknown task.")
|
||||
|
||||
|
||||
def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
|
||||
model_args, _, finetuning_args, _ = get_infer_args(args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
model.config.use_cache = True
|
||||
model.save_pretrained(model_args.export_dir, max_shard_size=max_shard_size)
|
||||
try:
|
||||
tokenizer.padding_side = "left" # restore padding side
|
||||
tokenizer.init_kwargs["padding_side"] = "left"
|
||||
tokenizer.save_pretrained(model_args.export_dir)
|
||||
except:
|
||||
logger.warning("Cannot save tokenizer, please copy the files manually.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_exp()
|
||||
@@ -0,0 +1 @@
|
||||
from llmtuner.webui.interface import create_ui, create_web_demo
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
from llmtuner.chat.stream_chat import ChatModel
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.hparams import GeneratingArguments
|
||||
from llmtuner.tuner import get_infer_args
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
|
||||
|
||||
class WebChatModel(ChatModel):
|
||||
|
||||
def __init__(self, *args):
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.generating_args = GeneratingArguments()
|
||||
if len(args) != 0:
|
||||
super().__init__(*args)
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
lang: str,
|
||||
model_name: str,
|
||||
checkpoints: List[str],
|
||||
finetuning_type: str,
|
||||
quantization_bit: str,
|
||||
template: str,
|
||||
source_prefix: str
|
||||
):
|
||||
if self.model is not None:
|
||||
yield ALERTS["err_exists"][lang]
|
||||
return
|
||||
|
||||
if not model_name:
|
||||
yield ALERTS["err_no_model"][lang]
|
||||
return
|
||||
|
||||
model_name_or_path = get_model_path(model_name)
|
||||
if not model_name_or_path:
|
||||
yield ALERTS["err_no_path"][lang]
|
||||
return
|
||||
|
||||
if checkpoints:
|
||||
checkpoint_dir = ",".join(
|
||||
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
|
||||
)
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
|
||||
yield ALERTS["info_loading"][lang]
|
||||
args = dict(
|
||||
model_name_or_path=model_name_or_path,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit else None,
|
||||
template=template,
|
||||
source_prefix=source_prefix
|
||||
)
|
||||
super().__init__(*get_infer_args(args))
|
||||
|
||||
yield ALERTS["info_loaded"][lang]
|
||||
|
||||
def unload_model(self, lang: str):
|
||||
yield ALERTS["info_unloading"][lang]
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
torch_gc()
|
||||
yield ALERTS["info_unloaded"][lang]
|
||||
|
||||
def predict(
|
||||
self,
|
||||
chatbot: List[Tuple[str, str]],
|
||||
query: str,
|
||||
history: List[Tuple[str, str]],
|
||||
prefix: str,
|
||||
max_new_tokens: int,
|
||||
top_p: float,
|
||||
temperature: float
|
||||
):
|
||||
chatbot.append([query, ""])
|
||||
response = ""
|
||||
for new_text in self.stream_chat(
|
||||
query, history, prefix, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
):
|
||||
response += new_text
|
||||
response = self.postprocess(response)
|
||||
new_history = history + [(query, response)]
|
||||
chatbot[-1] = [query, response]
|
||||
yield chatbot, new_history
|
||||
|
||||
def postprocess(self, response: str) -> str:
|
||||
blocks = response.split("```")
|
||||
for i, block in enumerate(blocks):
|
||||
if i % 2 == 0:
|
||||
blocks[i] = block.replace("<", "<").replace(">", ">")
|
||||
return "```".join(blocks)
|
||||
101
src/llmtuner/webui/chatter.py
Normal file
101
src/llmtuner/webui/chatter.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import gradio as gr
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
|
||||
|
||||
from llmtuner.chat.stream_chat import ChatModel
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.hparams import GeneratingArguments
|
||||
from llmtuner.webui.common import get_save_dir
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.webui.manager import Manager
|
||||
|
||||
|
||||
class WebChatModel(ChatModel):
|
||||
|
||||
def __init__(self, manager: "Manager", lazy_init: Optional[bool] = True) -> None:
|
||||
self.manager = manager
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
self.generating_args = GeneratingArguments()
|
||||
if not lazy_init:
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def loaded(self) -> bool:
|
||||
return self.model is not None
|
||||
|
||||
def load_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
lang = get("top.lang")
|
||||
error = ""
|
||||
if self.loaded:
|
||||
error = ALERTS["err_exists"][lang]
|
||||
elif not get("top.model_name"):
|
||||
error = ALERTS["err_no_model"][lang]
|
||||
elif not get("top.model_path"):
|
||||
error = ALERTS["err_no_path"][lang]
|
||||
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
yield error
|
||||
return
|
||||
|
||||
if get("top.checkpoints"):
|
||||
checkpoint_dir = ",".join([
|
||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
|
||||
])
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
|
||||
yield ALERTS["info_loading"][lang]
|
||||
args = dict(
|
||||
model_name_or_path=get("top.model_path"),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=get("top.finetuning_type"),
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
template=get("top.template"),
|
||||
system_prompt=get("top.system_prompt"),
|
||||
flash_attn=get("top.flash_attn"),
|
||||
shift_attn=get("top.shift_attn"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None
|
||||
)
|
||||
super().__init__(args)
|
||||
|
||||
yield ALERTS["info_loaded"][lang]
|
||||
|
||||
def unload_model(self, data: Dict[Component, Any]) -> Generator[str, None, None]:
|
||||
lang = data[self.manager.get_elem_by_name("top.lang")]
|
||||
yield ALERTS["info_unloading"][lang]
|
||||
self.model = None
|
||||
self.tokenizer = None
|
||||
torch_gc()
|
||||
yield ALERTS["info_unloaded"][lang]
|
||||
|
||||
def predict(
|
||||
self,
|
||||
chatbot: List[Tuple[str, str]],
|
||||
query: str,
|
||||
history: List[Tuple[str, str]],
|
||||
system: str,
|
||||
max_new_tokens: int,
|
||||
top_p: float,
|
||||
temperature: float
|
||||
) -> Generator[Tuple[List[Tuple[str, str]], List[Tuple[str, str]]], None, None]:
|
||||
chatbot.append([query, ""])
|
||||
response = ""
|
||||
for new_text in self.stream_chat(
|
||||
query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
):
|
||||
response += new_text
|
||||
new_history = history + [(query, response)]
|
||||
chatbot[-1] = [query, self.postprocess(response)]
|
||||
yield chatbot, new_history
|
||||
|
||||
def postprocess(self, response: str) -> str:
|
||||
blocks = response.split("```")
|
||||
for i, block in enumerate(blocks):
|
||||
if i % 2 == 0:
|
||||
blocks[i] = block.replace("<", "<").replace(">", ">")
|
||||
return "```".join(blocks)
|
||||
@@ -1,12 +1,17 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import json
|
||||
import gradio as gr
|
||||
from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
|
||||
from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||
from typing import Any, Dict, Optional
|
||||
from transformers.utils import (
|
||||
WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
ADAPTER_WEIGHTS_NAME,
|
||||
ADAPTER_SAFE_WEIGHTS_NAME
|
||||
)
|
||||
|
||||
from llmtuner.extras.constants import SUPPORTED_MODELS
|
||||
from llmtuner.extras.constants import DEFAULT_MODULE, DEFAULT_TEMPLATE, SUPPORTED_MODELS, TRAINING_STAGES
|
||||
|
||||
|
||||
DEFAULT_CACHE_DIR = "cache"
|
||||
@@ -14,10 +19,18 @@ DEFAULT_DATA_DIR = "data"
|
||||
DEFAULT_SAVE_DIR = "saves"
|
||||
USER_CONFIG = "user.config"
|
||||
DATA_CONFIG = "dataset_info.json"
|
||||
CKPT_NAMES = [
|
||||
WEIGHTS_NAME,
|
||||
WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
ADAPTER_WEIGHTS_NAME,
|
||||
ADAPTER_SAFE_WEIGHTS_NAME
|
||||
]
|
||||
|
||||
|
||||
def get_save_dir(model_name: str) -> str:
|
||||
return os.path.join(DEFAULT_SAVE_DIR, os.path.split(model_name)[-1])
|
||||
def get_save_dir(*args) -> os.PathLike:
|
||||
return os.path.join(DEFAULT_SAVE_DIR, *args)
|
||||
|
||||
|
||||
def get_config_path() -> os.PathLike:
|
||||
@@ -29,36 +42,46 @@ def load_config() -> Dict[str, Any]:
|
||||
with open(get_config_path(), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except:
|
||||
return {"last_model": "", "path_dict": {}}
|
||||
return {"lang": None, "last_model": None, "path_dict": {}, "cache_dir": None}
|
||||
|
||||
|
||||
def save_config(model_name: str, model_path: str) -> None:
|
||||
def save_config(lang: str, model_name: Optional[str] = None, model_path: Optional[str] = None) -> None:
|
||||
os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
|
||||
user_config = load_config()
|
||||
user_config["last_model"] = model_name
|
||||
user_config["path_dict"][model_name] = model_path
|
||||
user_config["lang"] = lang or user_config["lang"]
|
||||
if model_name:
|
||||
user_config["last_model"] = model_name
|
||||
user_config["path_dict"][model_name] = model_path
|
||||
with open(get_config_path(), "w", encoding="utf-8") as f:
|
||||
json.dump(user_config, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
def get_model_path(model_name: str) -> str:
|
||||
user_config = load_config()
|
||||
return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
|
||||
return user_config["path_dict"].get(model_name, None) or SUPPORTED_MODELS.get(model_name, "")
|
||||
|
||||
|
||||
def get_module(model_name: str) -> str:
|
||||
return DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj")
|
||||
|
||||
|
||||
def get_template(model_name: str) -> str:
|
||||
if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
|
||||
return DEFAULT_TEMPLATE[model_name.split("-")[0]]
|
||||
return "default"
|
||||
|
||||
|
||||
def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
|
||||
checkpoints = []
|
||||
save_dir = os.path.join(get_save_dir(model_name), finetuning_type)
|
||||
if save_dir and os.path.isdir(save_dir):
|
||||
for checkpoint in os.listdir(save_dir):
|
||||
if (
|
||||
os.path.isdir(os.path.join(save_dir, checkpoint))
|
||||
and any([
|
||||
os.path.isfile(os.path.join(save_dir, checkpoint, name))
|
||||
for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME)
|
||||
])
|
||||
):
|
||||
checkpoints.append(checkpoint)
|
||||
if model_name:
|
||||
save_dir = get_save_dir(model_name, finetuning_type)
|
||||
if save_dir and os.path.isdir(save_dir):
|
||||
for checkpoint in os.listdir(save_dir):
|
||||
if (
|
||||
os.path.isdir(os.path.join(save_dir, checkpoint))
|
||||
and any([os.path.isfile(os.path.join(save_dir, checkpoint, name)) for name in CKPT_NAMES])
|
||||
):
|
||||
checkpoints.append(checkpoint)
|
||||
return gr.update(value=[], choices=checkpoints)
|
||||
|
||||
|
||||
@@ -67,9 +90,14 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
return json.load(f)
|
||||
except:
|
||||
print("Cannot find {} in {}.".format(DATA_CONFIG, dataset_dir))
|
||||
return {}
|
||||
|
||||
|
||||
def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]:
|
||||
def list_dataset(
|
||||
dataset_dir: Optional[str] = None, training_stage: Optional[str] = list(TRAINING_STAGES.keys())[0]
|
||||
) -> Dict[str, Any]:
|
||||
dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
|
||||
return gr.update(value=[], choices=list(dataset_info.keys()))
|
||||
ranking = TRAINING_STAGES[training_stage] in ["rm", "dpo"]
|
||||
datasets = [k for k, v in dataset_info.items() if v.get("ranking", False) == ranking]
|
||||
return gr.update(value=[], choices=datasets)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from llmtuner.webui.components.top import create_top
|
||||
from llmtuner.webui.components.sft import create_sft_tab
|
||||
from llmtuner.webui.components.train import create_train_tab
|
||||
from llmtuner.webui.components.eval import create_eval_tab
|
||||
from llmtuner.webui.components.infer import create_infer_tab
|
||||
from llmtuner.webui.components.export import create_export_tab
|
||||
from llmtuner.webui.components.chatbot import create_chat_box
|
||||
|
||||
@@ -1,37 +1,35 @@
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
||||
|
||||
import gradio as gr
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.blocks import Block
|
||||
from gradio.components import Component
|
||||
from llmtuner.webui.chat import WebChatModel
|
||||
from llmtuner.webui.engine import Engine
|
||||
|
||||
|
||||
def create_chat_box(
|
||||
chat_model: "WebChatModel",
|
||||
engine: "Engine",
|
||||
visible: Optional[bool] = False
|
||||
) -> Tuple["Block", "Component", "Component", Dict[str, "Component"]]:
|
||||
with gr.Box(visible=visible) as chat_box:
|
||||
chatbot = gr.Chatbot()
|
||||
|
||||
history = gr.State([])
|
||||
with gr.Row():
|
||||
with gr.Column(scale=4):
|
||||
prefix = gr.Textbox(show_label=False)
|
||||
system = gr.Textbox(show_label=False)
|
||||
query = gr.Textbox(show_label=False, lines=8)
|
||||
submit_btn = gr.Button(variant="primary")
|
||||
|
||||
with gr.Column(scale=1):
|
||||
clear_btn = gr.Button()
|
||||
max_new_tokens = gr.Slider(10, 2048, value=chat_model.generating_args.max_new_tokens, step=1)
|
||||
top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01)
|
||||
temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01)
|
||||
|
||||
history = gr.State([])
|
||||
gen_kwargs = engine.chatter.generating_args
|
||||
max_new_tokens = gr.Slider(10, 2048, value=gen_kwargs.max_new_tokens, step=1)
|
||||
top_p = gr.Slider(0.01, 1, value=gen_kwargs.top_p, step=0.01)
|
||||
temperature = gr.Slider(0.01, 1.5, value=gen_kwargs.temperature, step=0.01)
|
||||
|
||||
submit_btn.click(
|
||||
chat_model.predict,
|
||||
[chatbot, query, history, prefix, max_new_tokens, top_p, temperature],
|
||||
engine.chatter.predict,
|
||||
[chatbot, query, history, system, max_new_tokens, top_p, temperature],
|
||||
[chatbot, history],
|
||||
show_progress=True
|
||||
).then(
|
||||
@@ -41,7 +39,7 @@ def create_chat_box(
|
||||
clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
|
||||
|
||||
return chat_box, chatbot, history, dict(
|
||||
prefix=prefix,
|
||||
system=system,
|
||||
query=query,
|
||||
submit_btn=submit_btn,
|
||||
clear_btn=clear_btn,
|
||||
|
||||
@@ -1,21 +1,103 @@
|
||||
import os
|
||||
import json
|
||||
import gradio as gr
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, Tuple
|
||||
|
||||
from llmtuner.webui.common import DATA_CONFIG
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.blocks import Block
|
||||
from gradio.components import Component
|
||||
|
||||
|
||||
def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"]:
|
||||
with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
|
||||
PAGE_SIZE = 2
|
||||
|
||||
|
||||
def prev_page(page_index: int) -> int:
|
||||
return page_index - 1 if page_index > 0 else page_index
|
||||
|
||||
|
||||
def next_page(page_index: int, total_num: int) -> int:
|
||||
return page_index + 1 if (page_index + 1) * PAGE_SIZE < total_num else page_index
|
||||
|
||||
|
||||
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
if (
|
||||
len(dataset) > 0
|
||||
and "file_name" in dataset_info[dataset[0]]
|
||||
and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
|
||||
):
|
||||
return gr.update(interactive=True)
|
||||
else:
|
||||
return gr.update(interactive=False)
|
||||
|
||||
|
||||
def get_preview(dataset_dir: str, dataset: list, page_index: int) -> Tuple[int, list, Dict[str, Any]]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
|
||||
data_file: str = dataset_info[dataset[0]]["file_name"]
|
||||
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
|
||||
if data_file.endswith(".json"):
|
||||
data = json.load(f)
|
||||
elif data_file.endswith(".jsonl"):
|
||||
data = [json.loads(line) for line in f]
|
||||
else:
|
||||
data = [line for line in f]
|
||||
return len(data), data[PAGE_SIZE * page_index : PAGE_SIZE * (page_index + 1)], gr.update(visible=True)
|
||||
|
||||
|
||||
def create_preview_box(dataset_dir: "gr.Textbox", dataset: "gr.Dropdown") -> Dict[str, "Component"]:
|
||||
data_preview_btn = gr.Button(interactive=False, scale=1)
|
||||
with gr.Column(visible=False, elem_classes="modal-box") as preview_box:
|
||||
with gr.Row():
|
||||
preview_count = gr.Number(interactive=False)
|
||||
preview_count = gr.Number(value=0, interactive=False, precision=0)
|
||||
page_index = gr.Number(value=0, interactive=False, precision=0)
|
||||
|
||||
with gr.Row():
|
||||
prev_btn = gr.Button()
|
||||
next_btn = gr.Button()
|
||||
close_btn = gr.Button()
|
||||
|
||||
with gr.Row():
|
||||
preview_samples = gr.JSON(interactive=False)
|
||||
|
||||
close_btn = gr.Button()
|
||||
|
||||
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box])
|
||||
|
||||
return preview_box, preview_count, preview_samples, close_btn
|
||||
dataset.change(
|
||||
can_preview, [dataset_dir, dataset], [data_preview_btn], queue=False
|
||||
).then(
|
||||
lambda: 0, outputs=[page_index], queue=False
|
||||
)
|
||||
data_preview_btn.click(
|
||||
get_preview,
|
||||
[dataset_dir, dataset, page_index],
|
||||
[preview_count, preview_samples, preview_box],
|
||||
queue=False
|
||||
)
|
||||
prev_btn.click(
|
||||
prev_page, [page_index], [page_index], queue=False
|
||||
).then(
|
||||
get_preview,
|
||||
[dataset_dir, dataset, page_index],
|
||||
[preview_count, preview_samples, preview_box],
|
||||
queue=False
|
||||
)
|
||||
next_btn.click(
|
||||
next_page, [page_index, preview_count], [page_index], queue=False
|
||||
).then(
|
||||
get_preview,
|
||||
[dataset_dir, dataset, page_index],
|
||||
[preview_count, preview_samples, preview_box],
|
||||
queue=False
|
||||
)
|
||||
close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
|
||||
return dict(
|
||||
data_preview_btn=data_preview_btn,
|
||||
preview_count=preview_count,
|
||||
page_index=page_index,
|
||||
prev_btn=prev_btn,
|
||||
next_btn=next_btn,
|
||||
close_btn=close_btn,
|
||||
preview_samples=preview_samples
|
||||
)
|
||||
|
||||
@@ -1,76 +1,70 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
import gradio as gr
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
|
||||
from llmtuner.webui.components.data import create_preview_box
|
||||
from llmtuner.webui.utils import can_preview, get_preview
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
from llmtuner.webui.runner import Runner
|
||||
from llmtuner.webui.engine import Engine
|
||||
|
||||
|
||||
def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
|
||||
def create_eval_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
input_elems = engine.manager.get_base_elems()
|
||||
elem_dict = dict()
|
||||
|
||||
with gr.Row():
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
preview_btn = gr.Button(interactive=False, scale=1)
|
||||
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||
|
||||
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
|
||||
dataset_dir.change(list_dataset, [dataset_dir], [dataset], queue=False)
|
||||
|
||||
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
|
||||
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
|
||||
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
|
||||
input_elems.update({dataset_dir, dataset})
|
||||
elem_dict.update(dict(dataset_dir=dataset_dir, dataset=dataset, **preview_elems))
|
||||
|
||||
with gr.Row():
|
||||
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
|
||||
max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
|
||||
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
|
||||
max_samples = gr.Textbox(value="100000")
|
||||
batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
|
||||
predict = gr.Checkbox(value=True)
|
||||
|
||||
input_elems.update({cutoff_len, max_samples, batch_size, predict})
|
||||
elem_dict.update(dict(
|
||||
cutoff_len=cutoff_len, max_samples=max_samples, batch_size=batch_size, predict=predict
|
||||
))
|
||||
|
||||
with gr.Row():
|
||||
max_new_tokens = gr.Slider(10, 2048, value=128, step=1)
|
||||
top_p = gr.Slider(0.01, 1, value=0.7, step=0.01)
|
||||
temperature = gr.Slider(0.01, 1.5, value=0.95, step=0.01)
|
||||
|
||||
input_elems.update({max_new_tokens, top_p, temperature})
|
||||
elem_dict.update(dict(
|
||||
max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
|
||||
))
|
||||
|
||||
with gr.Row():
|
||||
cmd_preview_btn = gr.Button()
|
||||
start_btn = gr.Button()
|
||||
stop_btn = gr.Button()
|
||||
|
||||
with gr.Row():
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
|
||||
process_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Box():
|
||||
output_box = gr.Markdown()
|
||||
|
||||
start_btn.click(
|
||||
runner.run_eval,
|
||||
[
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"],
|
||||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["source_prefix"],
|
||||
dataset_dir,
|
||||
dataset,
|
||||
max_source_length,
|
||||
max_target_length,
|
||||
max_samples,
|
||||
batch_size,
|
||||
predict
|
||||
],
|
||||
[output_box]
|
||||
)
|
||||
stop_btn.click(runner.set_abort, queue=False)
|
||||
output_elems = [output_box, process_bar]
|
||||
elem_dict.update(dict(
|
||||
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn,
|
||||
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box
|
||||
))
|
||||
|
||||
return dict(
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=dataset,
|
||||
preview_btn=preview_btn,
|
||||
preview_count=preview_count,
|
||||
preview_samples=preview_samples,
|
||||
close_btn=close_btn,
|
||||
max_source_length=max_source_length,
|
||||
max_target_length=max_target_length,
|
||||
max_samples=max_samples,
|
||||
batch_size=batch_size,
|
||||
predict=predict,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
||||
output_box=output_box
|
||||
)
|
||||
cmd_preview_btn.click(engine.runner.preview_eval, input_elems, output_elems)
|
||||
start_btn.click(engine.runner.run_eval, input_elems, output_elems)
|
||||
stop_btn.click(engine.runner.set_abort, queue=False)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems)
|
||||
|
||||
return elem_dict
|
||||
|
||||
@@ -1,35 +1,78 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
import gradio as gr
|
||||
from typing import TYPE_CHECKING, Dict, Generator, List
|
||||
|
||||
from llmtuner.webui.utils import export_model
|
||||
from llmtuner.tuner import export_model
|
||||
from llmtuner.webui.common import get_save_dir
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
from llmtuner.webui.engine import Engine
|
||||
|
||||
|
||||
def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
|
||||
def save_model(
|
||||
lang: str,
|
||||
model_name: str,
|
||||
model_path: str,
|
||||
checkpoints: List[str],
|
||||
finetuning_type: str,
|
||||
template: str,
|
||||
max_shard_size: int,
|
||||
export_dir: str
|
||||
) -> Generator[str, None, None]:
|
||||
error = ""
|
||||
if not model_name:
|
||||
error = ALERTS["err_no_model"][lang]
|
||||
elif not model_path:
|
||||
error = ALERTS["err_no_path"][lang]
|
||||
elif not checkpoints:
|
||||
error = ALERTS["err_no_checkpoint"][lang]
|
||||
elif not export_dir:
|
||||
error = ALERTS["err_no_export_dir"][lang]
|
||||
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
yield error
|
||||
return
|
||||
|
||||
args = dict(
|
||||
model_name_or_path=model_path,
|
||||
checkpoint_dir=",".join([get_save_dir(model_name, finetuning_type, ckpt) for ckpt in checkpoints]),
|
||||
finetuning_type=finetuning_type,
|
||||
template=template,
|
||||
export_dir=export_dir
|
||||
)
|
||||
|
||||
yield ALERTS["info_exporting"][lang]
|
||||
export_model(args, max_shard_size="{}GB".format(max_shard_size))
|
||||
yield ALERTS["info_exported"][lang]
|
||||
|
||||
|
||||
def create_export_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
save_dir = gr.Textbox()
|
||||
export_dir = gr.Textbox()
|
||||
max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
|
||||
|
||||
export_btn = gr.Button()
|
||||
info_box = gr.Textbox(show_label=False, interactive=False)
|
||||
|
||||
export_btn.click(
|
||||
export_model,
|
||||
save_model,
|
||||
[
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"],
|
||||
engine.manager.get_elem_by_name("top.lang"),
|
||||
engine.manager.get_elem_by_name("top.model_name"),
|
||||
engine.manager.get_elem_by_name("top.model_path"),
|
||||
engine.manager.get_elem_by_name("top.checkpoints"),
|
||||
engine.manager.get_elem_by_name("top.finetuning_type"),
|
||||
engine.manager.get_elem_by_name("top.template"),
|
||||
max_shard_size,
|
||||
save_dir
|
||||
export_dir
|
||||
],
|
||||
[info_box]
|
||||
)
|
||||
|
||||
return dict(
|
||||
save_dir=save_dir,
|
||||
export_dir=export_dir,
|
||||
max_shard_size=max_shard_size,
|
||||
export_btn=export_btn,
|
||||
info_box=info_box
|
||||
|
||||
@@ -1,51 +1,39 @@
|
||||
import gradio as gr
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from llmtuner.webui.chat import WebChatModel
|
||||
from llmtuner.webui.components.chatbot import create_chat_box
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
from llmtuner.webui.engine import Engine
|
||||
|
||||
|
||||
def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"]:
|
||||
def create_infer_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
input_elems = engine.manager.get_base_elems()
|
||||
elem_dict = dict()
|
||||
|
||||
with gr.Row():
|
||||
load_btn = gr.Button()
|
||||
unload_btn = gr.Button()
|
||||
|
||||
info_box = gr.Textbox(show_label=False, interactive=False)
|
||||
elem_dict.update(dict(load_btn=load_btn, unload_btn=unload_btn, info_box=info_box))
|
||||
|
||||
chat_model = WebChatModel()
|
||||
chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)
|
||||
chat_box, chatbot, history, chat_elems = create_chat_box(engine, visible=False)
|
||||
elem_dict.update(dict(chat_box=chat_box, **chat_elems))
|
||||
|
||||
load_btn.click(
|
||||
chat_model.load_model,
|
||||
[
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"],
|
||||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["source_prefix"]
|
||||
],
|
||||
[info_box]
|
||||
engine.chatter.load_model, input_elems, [info_box]
|
||||
).then(
|
||||
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
|
||||
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
|
||||
)
|
||||
|
||||
unload_btn.click(
|
||||
chat_model.unload_model, [top_elems["lang"]], [info_box]
|
||||
engine.chatter.unload_model, input_elems, [info_box]
|
||||
).then(
|
||||
lambda: ([], []), outputs=[chatbot, history]
|
||||
).then(
|
||||
lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
|
||||
lambda: gr.update(visible=engine.chatter.loaded), outputs=[chat_box]
|
||||
)
|
||||
|
||||
return dict(
|
||||
info_box=info_box,
|
||||
load_btn=load_btn,
|
||||
unload_btn=unload_btn,
|
||||
**chat_elems
|
||||
)
|
||||
return elem_dict
|
||||
|
||||
@@ -1,140 +0,0 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
from transformers.trainer_utils import SchedulerType
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
|
||||
from llmtuner.webui.components.data import create_preview_box
|
||||
from llmtuner.webui.utils import can_preview, get_preview, gen_plot
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
from llmtuner.webui.runner import Runner
|
||||
|
||||
|
||||
def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
|
||||
with gr.Row():
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
preview_btn = gr.Button(interactive=False, scale=1)
|
||||
|
||||
preview_box, preview_count, preview_samples, close_btn = create_preview_box()
|
||||
|
||||
dataset_dir.change(list_dataset, [dataset_dir], [dataset])
|
||||
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
|
||||
preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
|
||||
|
||||
with gr.Row():
|
||||
max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
|
||||
max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
|
||||
learning_rate = gr.Textbox(value="5e-5")
|
||||
num_train_epochs = gr.Textbox(value="3.0")
|
||||
max_samples = gr.Textbox(value="100000")
|
||||
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
|
||||
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
|
||||
lr_scheduler_type = gr.Dropdown(
|
||||
value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
|
||||
)
|
||||
max_grad_norm = gr.Textbox(value="1.0")
|
||||
dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
|
||||
|
||||
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
|
||||
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
|
||||
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
|
||||
compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
|
||||
|
||||
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
|
||||
with gr.Row():
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
|
||||
lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
lora_target = gr.Textbox(scale=2)
|
||||
|
||||
with gr.Row():
|
||||
start_btn = gr.Button()
|
||||
stop_btn = gr.Button()
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
output_dir = gr.Textbox()
|
||||
|
||||
with gr.Box():
|
||||
output_box = gr.Markdown()
|
||||
|
||||
with gr.Column(scale=1):
|
||||
loss_viewer = gr.Plot()
|
||||
|
||||
start_btn.click(
|
||||
runner.run_train,
|
||||
[
|
||||
top_elems["lang"],
|
||||
top_elems["model_name"],
|
||||
top_elems["checkpoints"],
|
||||
top_elems["finetuning_type"],
|
||||
top_elems["quantization_bit"],
|
||||
top_elems["template"],
|
||||
top_elems["source_prefix"],
|
||||
dataset_dir,
|
||||
dataset,
|
||||
max_source_length,
|
||||
max_target_length,
|
||||
learning_rate,
|
||||
num_train_epochs,
|
||||
max_samples,
|
||||
batch_size,
|
||||
gradient_accumulation_steps,
|
||||
lr_scheduler_type,
|
||||
max_grad_norm,
|
||||
dev_ratio,
|
||||
logging_steps,
|
||||
save_steps,
|
||||
warmup_steps,
|
||||
compute_type,
|
||||
lora_rank,
|
||||
lora_dropout,
|
||||
lora_target,
|
||||
output_dir
|
||||
],
|
||||
[output_box]
|
||||
)
|
||||
stop_btn.click(runner.set_abort, queue=False)
|
||||
|
||||
output_box.change(
|
||||
gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
|
||||
)
|
||||
|
||||
return dict(
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=dataset,
|
||||
preview_btn=preview_btn,
|
||||
preview_count=preview_count,
|
||||
preview_samples=preview_samples,
|
||||
close_btn=close_btn,
|
||||
max_source_length=max_source_length,
|
||||
max_target_length=max_target_length,
|
||||
learning_rate=learning_rate,
|
||||
num_train_epochs=num_train_epochs,
|
||||
max_samples=max_samples,
|
||||
batch_size=batch_size,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
lr_scheduler_type=lr_scheduler_type,
|
||||
max_grad_norm=max_grad_norm,
|
||||
dev_ratio=dev_ratio,
|
||||
advanced_tab=advanced_tab,
|
||||
logging_steps=logging_steps,
|
||||
save_steps=save_steps,
|
||||
warmup_steps=warmup_steps,
|
||||
compute_type=compute_type,
|
||||
lora_tab=lora_tab,
|
||||
lora_rank=lora_rank,
|
||||
lora_dropout=lora_dropout,
|
||||
lora_target=lora_target,
|
||||
start_btn=start_btn,
|
||||
stop_btn=stop_btn,
|
||||
output_dir=output_dir,
|
||||
output_box=output_box,
|
||||
loss_viewer=loss_viewer
|
||||
)
|
||||
@@ -1,10 +1,9 @@
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
import gradio as gr
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
|
||||
from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
|
||||
from llmtuner.extras.template import templates
|
||||
from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
|
||||
from llmtuner.webui.common import get_model_path, get_template, list_checkpoint, save_config
|
||||
from llmtuner.webui.utils import can_quantize
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -15,35 +14,47 @@ def create_top() -> Dict[str, "Component"]:
|
||||
available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
|
||||
|
||||
with gr.Row():
|
||||
lang = gr.Dropdown(choices=["en", "zh"], value="en", scale=1)
|
||||
lang = gr.Dropdown(choices=["en", "zh"], scale=1)
|
||||
model_name = gr.Dropdown(choices=available_models, scale=3)
|
||||
model_path = gr.Textbox(scale=3)
|
||||
|
||||
with gr.Row():
|
||||
finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
|
||||
finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
|
||||
checkpoints = gr.Dropdown(multiselect=True, scale=5)
|
||||
refresh_btn = gr.Button(scale=1)
|
||||
|
||||
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
quantization_bit = gr.Dropdown([8, 4], scale=1)
|
||||
template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=1)
|
||||
source_prefix = gr.Textbox(scale=2)
|
||||
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=1)
|
||||
template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
|
||||
system_prompt = gr.Textbox(scale=2)
|
||||
|
||||
with gr.Accordion(label="Model config (LLaMA only)", open=False) as llama_tab:
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
flash_attn = gr.Checkbox(value=False)
|
||||
shift_attn = gr.Checkbox(value=False)
|
||||
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none")
|
||||
|
||||
model_name.change(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
|
||||
).then(
|
||||
get_model_path, [model_name], [model_path]
|
||||
get_model_path, [model_name], [model_path], queue=False
|
||||
).then(
|
||||
get_template, [model_name], [template], queue=False
|
||||
) # do not save config since the below line will save
|
||||
model_path.change(save_config, [model_name, model_path])
|
||||
|
||||
model_path.change(save_config, inputs=[lang, model_name, model_path], queue=False)
|
||||
|
||||
finetuning_type.change(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints]
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
|
||||
).then(
|
||||
can_quantize, [finetuning_type], [quantization_bit]
|
||||
can_quantize, [finetuning_type], [quantization_bit], queue=False
|
||||
)
|
||||
|
||||
refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
|
||||
refresh_btn.click(
|
||||
list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
|
||||
)
|
||||
|
||||
return dict(
|
||||
lang=lang,
|
||||
@@ -55,5 +66,9 @@ def create_top() -> Dict[str, "Component"]:
|
||||
advanced_tab=advanced_tab,
|
||||
quantization_bit=quantization_bit,
|
||||
template=template,
|
||||
source_prefix=source_prefix
|
||||
system_prompt=system_prompt,
|
||||
llama_tab=llama_tab,
|
||||
flash_attn=flash_attn,
|
||||
shift_attn=shift_attn,
|
||||
rope_scaling=rope_scaling
|
||||
)
|
||||
|
||||
154
src/llmtuner/webui/components/train.py
Normal file
154
src/llmtuner/webui/components/train.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import gradio as gr
|
||||
from typing import TYPE_CHECKING, Dict
|
||||
from transformers.trainer_utils import SchedulerType
|
||||
|
||||
from llmtuner.extras.constants import TRAINING_STAGES
|
||||
from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
|
||||
from llmtuner.webui.components.data import create_preview_box
|
||||
from llmtuner.webui.utils import gen_plot
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
from llmtuner.webui.engine import Engine
|
||||
|
||||
|
||||
def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
input_elems = engine.manager.get_base_elems()
|
||||
elem_dict = dict()
|
||||
|
||||
with gr.Row():
|
||||
training_stage = gr.Dropdown(
|
||||
choices=list(TRAINING_STAGES.keys()), value=list(TRAINING_STAGES.keys())[0], scale=2
|
||||
)
|
||||
dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
|
||||
dataset = gr.Dropdown(multiselect=True, scale=4)
|
||||
preview_elems = create_preview_box(dataset_dir, dataset)
|
||||
|
||||
training_stage.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
||||
dataset_dir.change(list_dataset, [dataset_dir, training_stage], [dataset], queue=False)
|
||||
|
||||
input_elems.update({training_stage, dataset_dir, dataset})
|
||||
elem_dict.update(dict(
|
||||
training_stage=training_stage, dataset_dir=dataset_dir, dataset=dataset, **preview_elems
|
||||
))
|
||||
|
||||
with gr.Row():
|
||||
cutoff_len = gr.Slider(value=1024, minimum=4, maximum=8192, step=1)
|
||||
learning_rate = gr.Textbox(value="5e-5")
|
||||
num_train_epochs = gr.Textbox(value="3.0")
|
||||
max_samples = gr.Textbox(value="100000")
|
||||
compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
|
||||
|
||||
input_elems.update({cutoff_len, learning_rate, num_train_epochs, max_samples, compute_type})
|
||||
elem_dict.update(dict(
|
||||
cutoff_len=cutoff_len, learning_rate=learning_rate, num_train_epochs=num_train_epochs,
|
||||
max_samples=max_samples, compute_type=compute_type
|
||||
))
|
||||
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
|
||||
gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
|
||||
lr_scheduler_type = gr.Dropdown(
|
||||
choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
|
||||
)
|
||||
max_grad_norm = gr.Textbox(value="1.0")
|
||||
val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
|
||||
|
||||
input_elems.update({batch_size, gradient_accumulation_steps, lr_scheduler_type, max_grad_norm, val_size})
|
||||
elem_dict.update(dict(
|
||||
batch_size=batch_size, gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, val_size=val_size
|
||||
))
|
||||
|
||||
with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
|
||||
save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
|
||||
warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
|
||||
neft_alpha = gr.Slider(value=0, minimum=0, maximum=10, step=0.1)
|
||||
|
||||
with gr.Column():
|
||||
train_on_prompt = gr.Checkbox(value=False)
|
||||
upcast_layernorm = gr.Checkbox(value=False)
|
||||
|
||||
input_elems.update({logging_steps, save_steps, warmup_steps, neft_alpha, train_on_prompt, upcast_layernorm})
|
||||
elem_dict.update(dict(
|
||||
advanced_tab=advanced_tab, logging_steps=logging_steps, save_steps=save_steps, warmup_steps=warmup_steps,
|
||||
neft_alpha=neft_alpha, train_on_prompt=train_on_prompt, upcast_layernorm=upcast_layernorm
|
||||
))
|
||||
|
||||
with gr.Accordion(label="LoRA config", open=False) as lora_tab:
|
||||
with gr.Row():
|
||||
lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
|
||||
lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
lora_target = gr.Textbox(scale=1)
|
||||
additional_target = gr.Textbox(scale=1)
|
||||
resume_lora_training = gr.Checkbox(value=True, scale=1)
|
||||
|
||||
input_elems.update({lora_rank, lora_dropout, lora_target, additional_target, resume_lora_training})
|
||||
elem_dict.update(dict(
|
||||
lora_tab=lora_tab, lora_rank=lora_rank, lora_dropout=lora_dropout, lora_target=lora_target,
|
||||
additional_target=additional_target, resume_lora_training=resume_lora_training,
|
||||
))
|
||||
|
||||
with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
|
||||
with gr.Row():
|
||||
dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
|
||||
reward_model = gr.Dropdown(scale=3)
|
||||
refresh_btn = gr.Button(scale=1)
|
||||
|
||||
refresh_btn.click(
|
||||
list_checkpoint,
|
||||
[engine.manager.get_elem_by_name("top.model_name"), engine.manager.get_elem_by_name("top.finetuning_type")],
|
||||
[reward_model],
|
||||
queue=False
|
||||
)
|
||||
|
||||
input_elems.update({dpo_beta, reward_model})
|
||||
elem_dict.update(dict(rlhf_tab=rlhf_tab, dpo_beta=dpo_beta, reward_model=reward_model, refresh_btn=refresh_btn))
|
||||
|
||||
with gr.Row():
|
||||
cmd_preview_btn = gr.Button()
|
||||
start_btn = gr.Button()
|
||||
stop_btn = gr.Button()
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
with gr.Row():
|
||||
output_dir = gr.Textbox()
|
||||
|
||||
with gr.Row():
|
||||
resume_btn = gr.Checkbox(visible=False, interactive=False, value=False)
|
||||
process_bar = gr.Slider(visible=False, interactive=False)
|
||||
|
||||
with gr.Box():
|
||||
output_box = gr.Markdown()
|
||||
|
||||
with gr.Column(scale=1):
|
||||
loss_viewer = gr.Plot()
|
||||
|
||||
input_elems.add(output_dir)
|
||||
output_elems = [output_box, process_bar]
|
||||
|
||||
cmd_preview_btn.click(engine.runner.preview_train, input_elems, output_elems)
|
||||
start_btn.click(engine.runner.run_train, input_elems, output_elems)
|
||||
stop_btn.click(engine.runner.set_abort, queue=False)
|
||||
resume_btn.change(engine.runner.monitor, outputs=output_elems)
|
||||
|
||||
elem_dict.update(dict(
|
||||
cmd_preview_btn=cmd_preview_btn, start_btn=start_btn, stop_btn=stop_btn, output_dir=output_dir,
|
||||
resume_btn=resume_btn, process_bar=process_bar, output_box=output_box, loss_viewer=loss_viewer
|
||||
))
|
||||
|
||||
output_box.change(
|
||||
gen_plot,
|
||||
[
|
||||
engine.manager.get_elem_by_name("top.model_name"),
|
||||
engine.manager.get_elem_by_name("top.finetuning_type"),
|
||||
output_dir
|
||||
],
|
||||
loss_viewer,
|
||||
queue=False
|
||||
)
|
||||
|
||||
return elem_dict
|
||||
@@ -6,10 +6,12 @@ CSS = r"""
|
||||
transform: translate(-50%, -50%); /* center horizontally */
|
||||
max-width: 1000px;
|
||||
max-height: 750px;
|
||||
overflow-y: scroll !important;
|
||||
overflow-y: auto;
|
||||
background-color: var(--input-background-fill);
|
||||
flex-wrap: nowrap !important;
|
||||
border: 2px solid black !important;
|
||||
z-index: 1000;
|
||||
padding: 10px;
|
||||
}
|
||||
|
||||
.dark .modal-box {
|
||||
|
||||
57
src/llmtuner/webui/engine.py
Normal file
57
src/llmtuner/webui/engine.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import gradio as gr
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
from typing import Any, Dict, Generator, Optional
|
||||
|
||||
from llmtuner.webui.chatter import WebChatModel
|
||||
from llmtuner.webui.common import get_model_path, list_dataset, load_config
|
||||
from llmtuner.webui.locales import LOCALES
|
||||
from llmtuner.webui.manager import Manager
|
||||
from llmtuner.webui.runner import Runner
|
||||
from llmtuner.webui.utils import get_time
|
||||
|
||||
|
||||
class Engine:
|
||||
|
||||
def __init__(self, pure_chat: Optional[bool] = False) -> None:
|
||||
self.pure_chat = pure_chat
|
||||
self.manager: "Manager" = Manager()
|
||||
self.runner: "Runner" = Runner(self.manager)
|
||||
self.chatter: "WebChatModel" = WebChatModel(manager=self.manager, lazy_init=(not pure_chat))
|
||||
|
||||
def _form_dict(self, resume_dict: Dict[str, Dict[str, Any]]):
|
||||
return {self.manager.get_elem_by_name(k): gr.update(**v) for k, v in resume_dict.items()}
|
||||
|
||||
def resume(self) -> Generator[Dict[Component, Dict[str, Any]], None, None]:
|
||||
user_config = load_config()
|
||||
lang = user_config.get("lang", None) or "en"
|
||||
|
||||
init_dict = {
|
||||
"top.lang": {"value": lang},
|
||||
"infer.chat_box": {"visible": self.chatter.loaded}
|
||||
}
|
||||
|
||||
if not self.pure_chat:
|
||||
init_dict["train.dataset"] = {"choices": list_dataset()["choices"]}
|
||||
init_dict["eval.dataset"] = {"choices": list_dataset()["choices"]}
|
||||
|
||||
if user_config.get("last_model", None):
|
||||
init_dict["top.model_name"] = {"value": user_config["last_model"]}
|
||||
init_dict["top.model_path"] = {"value": get_model_path(user_config["last_model"])}
|
||||
|
||||
yield self._form_dict(init_dict)
|
||||
|
||||
if not self.pure_chat:
|
||||
if self.runner.alive:
|
||||
yield {elem: gr.update(value=value) for elem, value in self.runner.running_data.items()}
|
||||
if self.runner.do_train:
|
||||
yield self._form_dict({"train.resume_btn": {"value": True}})
|
||||
else:
|
||||
yield self._form_dict({"eval.resume_btn": {"value": True}})
|
||||
else:
|
||||
yield self._form_dict({"train.output_dir": {"value": get_time()}})
|
||||
|
||||
def change_lang(self, lang: str) -> Dict[Component, Dict[str, Any]]:
|
||||
return {
|
||||
component: gr.update(**LOCALES[name][lang])
|
||||
for elems in self.manager.all_elems.values() for name, component in elems.items() if name in LOCALES
|
||||
}
|
||||
@@ -3,51 +3,59 @@ from transformers.utils.versions import require_version
|
||||
|
||||
from llmtuner.webui.components import (
|
||||
create_top,
|
||||
create_sft_tab,
|
||||
create_train_tab,
|
||||
create_eval_tab,
|
||||
create_infer_tab,
|
||||
create_export_tab
|
||||
create_export_tab,
|
||||
create_chat_box
|
||||
)
|
||||
from llmtuner.webui.common import save_config
|
||||
from llmtuner.webui.css import CSS
|
||||
from llmtuner.webui.manager import Manager
|
||||
from llmtuner.webui.runner import Runner
|
||||
from llmtuner.webui.engine import Engine
|
||||
|
||||
|
||||
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
|
||||
require_version("gradio>=3.38.0,<4.0.0", "To fix: pip install \"gradio>=3.38.0,<4.0.0\"")
|
||||
|
||||
|
||||
def create_ui() -> gr.Blocks:
|
||||
runner = Runner()
|
||||
engine = Engine(pure_chat=False)
|
||||
|
||||
with gr.Blocks(title="Web Tuner", css=CSS) as demo:
|
||||
top_elems = create_top()
|
||||
with gr.Blocks(title="LLaMA Board", css=CSS) as demo:
|
||||
engine.manager.all_elems["top"] = create_top()
|
||||
lang: "gr.Dropdown" = engine.manager.get_elem_by_name("top.lang")
|
||||
|
||||
with gr.Tab("SFT"):
|
||||
sft_elems = create_sft_tab(top_elems, runner)
|
||||
with gr.Tab("Train"):
|
||||
engine.manager.all_elems["train"] = create_train_tab(engine)
|
||||
|
||||
with gr.Tab("Evaluate"):
|
||||
eval_elems = create_eval_tab(top_elems, runner)
|
||||
engine.manager.all_elems["eval"] = create_eval_tab(engine)
|
||||
|
||||
with gr.Tab("Chat"):
|
||||
infer_elems = create_infer_tab(top_elems)
|
||||
engine.manager.all_elems["infer"] = create_infer_tab(engine)
|
||||
|
||||
with gr.Tab("Export"):
|
||||
export_elems = create_export_tab(top_elems)
|
||||
engine.manager.all_elems["export"] = create_export_tab(engine)
|
||||
|
||||
elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems]
|
||||
manager = Manager(elem_list)
|
||||
demo.load(engine.resume, outputs=engine.manager.list_elems())
|
||||
lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
|
||||
lang.input(save_config, inputs=[lang], queue=False)
|
||||
|
||||
demo.load(
|
||||
manager.gen_label,
|
||||
[top_elems["lang"]],
|
||||
[elem for elems in elem_list for elem in elems.values()],
|
||||
)
|
||||
return demo
|
||||
|
||||
top_elems["lang"].change(
|
||||
manager.gen_label,
|
||||
[top_elems["lang"]],
|
||||
[elem for elems in elem_list for elem in elems.values()],
|
||||
)
|
||||
|
||||
def create_web_demo() -> gr.Blocks:
|
||||
engine = Engine(pure_chat=True)
|
||||
|
||||
with gr.Blocks(title="Web Demo", css=CSS) as demo:
|
||||
lang = gr.Dropdown(choices=["en", "zh"])
|
||||
engine.manager.all_elems["top"] = dict(lang=lang)
|
||||
|
||||
chat_box, _, _, chat_elems = create_chat_box(engine, visible=True)
|
||||
engine.manager.all_elems["infer"] = dict(chat_box=chat_box, **chat_elems)
|
||||
|
||||
demo.load(engine.resume, outputs=engine.manager.list_elems())
|
||||
lang.change(engine.change_lang, [lang], engine.manager.list_elems(), queue=False)
|
||||
lang.input(save_config, inputs=[lang], queue=False)
|
||||
|
||||
return demo
|
||||
|
||||
|
||||
@@ -59,12 +59,12 @@ LOCALES = {
|
||||
},
|
||||
"quantization_bit": {
|
||||
"en": {
|
||||
"label": "Quantization bit (optional)",
|
||||
"info": "Enable 4/8-bit model quantization."
|
||||
"label": "Quantization bit",
|
||||
"info": "Enable 4/8-bit model quantization (QLoRA)."
|
||||
},
|
||||
"zh": {
|
||||
"label": "量化等级(非必填)",
|
||||
"info": "启用 4/8 比特模型量化。"
|
||||
"label": "量化等级",
|
||||
"info": "启用 4/8 比特模型量化(QLoRA)。"
|
||||
}
|
||||
},
|
||||
"template": {
|
||||
@@ -77,7 +77,7 @@ LOCALES = {
|
||||
"info": "构建提示词时使用的模板"
|
||||
}
|
||||
},
|
||||
"source_prefix": {
|
||||
"system_prompt": {
|
||||
"en": {
|
||||
"label": "System prompt (optional)",
|
||||
"info": "A sequence used as the default system prompt."
|
||||
@@ -87,6 +87,48 @@ LOCALES = {
|
||||
"info": "默认使用的系统提示词"
|
||||
}
|
||||
},
|
||||
"llama_tab": {
|
||||
"en": {
|
||||
"label": "Model configurations (LLaMA only)"
|
||||
},
|
||||
"zh": {
|
||||
"label": "模型设置(仅LLaMA)"
|
||||
}
|
||||
},
|
||||
"flash_attn": {
|
||||
"en": {
|
||||
"label": "Use FlashAttention-2"
|
||||
},
|
||||
"zh": {
|
||||
"label": "使用 FlashAttention-2"
|
||||
}
|
||||
},
|
||||
"shift_attn": {
|
||||
"en": {
|
||||
"label": "Use shift short attention (S^2-Attn)"
|
||||
},
|
||||
"zh": {
|
||||
"label": "使用 shift short attention (S^2-Attn)"
|
||||
}
|
||||
},
|
||||
"rope_scaling": {
|
||||
"en": {
|
||||
"label": "RoPE scaling"
|
||||
},
|
||||
"zh": {
|
||||
"label": "RoPE 插值方法"
|
||||
}
|
||||
},
|
||||
"training_stage": {
|
||||
"en": {
|
||||
"label": "Stage",
|
||||
"info": "The stage to perform in training."
|
||||
},
|
||||
"zh": {
|
||||
"label": "训练阶段",
|
||||
"info": "目前采用的训练方式。"
|
||||
}
|
||||
},
|
||||
"dataset_dir": {
|
||||
"en": {
|
||||
"label": "Data dir",
|
||||
@@ -105,12 +147,12 @@ LOCALES = {
|
||||
"label": "数据集"
|
||||
}
|
||||
},
|
||||
"preview_btn": {
|
||||
"data_preview_btn": {
|
||||
"en": {
|
||||
"value": "Preview"
|
||||
"value": "Preview dataset"
|
||||
},
|
||||
"zh": {
|
||||
"value": "预览"
|
||||
"value": "预览数据集"
|
||||
}
|
||||
},
|
||||
"preview_count": {
|
||||
@@ -121,12 +163,28 @@ LOCALES = {
|
||||
"label": "数量"
|
||||
}
|
||||
},
|
||||
"preview_samples": {
|
||||
"page_index": {
|
||||
"en": {
|
||||
"label": "Samples"
|
||||
"label": "Page"
|
||||
},
|
||||
"zh": {
|
||||
"label": "样例"
|
||||
"label": "页数"
|
||||
}
|
||||
},
|
||||
"prev_btn": {
|
||||
"en": {
|
||||
"value": "Prev"
|
||||
},
|
||||
"zh": {
|
||||
"value": "上一页"
|
||||
}
|
||||
},
|
||||
"next_btn": {
|
||||
"en": {
|
||||
"value": "Next"
|
||||
},
|
||||
"zh": {
|
||||
"value": "下一页"
|
||||
}
|
||||
},
|
||||
"close_btn": {
|
||||
@@ -137,24 +195,22 @@ LOCALES = {
|
||||
"value": "关闭"
|
||||
}
|
||||
},
|
||||
"max_source_length": {
|
||||
"preview_samples": {
|
||||
"en": {
|
||||
"label": "Max source length",
|
||||
"info": "Max tokens in source sequence."
|
||||
"label": "Samples"
|
||||
},
|
||||
"zh": {
|
||||
"label": "输入序列最大长度",
|
||||
"info": "输入序列分词后的最大长度。"
|
||||
"label": "样例"
|
||||
}
|
||||
},
|
||||
"max_target_length": {
|
||||
"cutoff_len": {
|
||||
"en": {
|
||||
"label": "Max target length",
|
||||
"info": "Max tokens in target sequence."
|
||||
"label": "Cutoff length",
|
||||
"info": "Max tokens in input sequence."
|
||||
},
|
||||
"zh": {
|
||||
"label": "输出序列最大长度",
|
||||
"info": "输出序列分词后的最大长度。"
|
||||
"label": "截断长度",
|
||||
"info": "输入序列分词后的最大长度。"
|
||||
}
|
||||
},
|
||||
"learning_rate": {
|
||||
@@ -187,6 +243,16 @@ LOCALES = {
|
||||
"info": "每个数据集最多使用的样本数。"
|
||||
}
|
||||
},
|
||||
"compute_type": {
|
||||
"en": {
|
||||
"label": "Compute type",
|
||||
"info": "Whether to use fp16 or bf16 mixed precision training."
|
||||
},
|
||||
"zh": {
|
||||
"label": "计算类型",
|
||||
"info": "是否启用 FP16 或 BF16 混合精度训练。"
|
||||
}
|
||||
},
|
||||
"batch_size": {
|
||||
"en": {
|
||||
"label": "Batch size",
|
||||
@@ -227,9 +293,9 @@ LOCALES = {
|
||||
"info": "用于梯度裁剪的范数。"
|
||||
}
|
||||
},
|
||||
"dev_ratio": {
|
||||
"val_size": {
|
||||
"en": {
|
||||
"label": "Dev ratio",
|
||||
"label": "Val size",
|
||||
"info": "Proportion of data in the dev set."
|
||||
},
|
||||
"zh": {
|
||||
@@ -267,14 +333,34 @@ LOCALES = {
|
||||
"info": "学习率预热采用的步数。"
|
||||
}
|
||||
},
|
||||
"compute_type": {
|
||||
"neft_alpha": {
|
||||
"en": {
|
||||
"label": "Compute type",
|
||||
"info": "Whether to use fp16 or bf16 mixed precision training."
|
||||
"label": "NEFTune Alpha",
|
||||
"info": "Magnitude of noise adding to embedding vectors."
|
||||
},
|
||||
"zh": {
|
||||
"label": "计算类型",
|
||||
"info": "是否启用 FP16 或 BF16 混合精度训练。"
|
||||
"label": "NEFTune 噪声参数",
|
||||
"info": "嵌入向量所添加的噪声大小。"
|
||||
}
|
||||
},
|
||||
"train_on_prompt": {
|
||||
"en": {
|
||||
"label": "Train on prompt",
|
||||
"info": "Compute loss on the prompt tokens in supervised fine-tuning."
|
||||
},
|
||||
"zh": {
|
||||
"label": "计算输入损失",
|
||||
"info": "在监督微调时候计算输入序列的损失。"
|
||||
}
|
||||
},
|
||||
"upcast_layernorm": {
|
||||
"en": {
|
||||
"label": "Upcast LayerNorm",
|
||||
"info": "Upcast weights of layernorm in float32."
|
||||
},
|
||||
"zh": {
|
||||
"label": "缩放归一化层",
|
||||
"info": "将归一化层权重缩放至 32 位浮点数。"
|
||||
}
|
||||
},
|
||||
"lora_tab": {
|
||||
@@ -308,11 +394,67 @@ LOCALES = {
|
||||
"lora_target": {
|
||||
"en": {
|
||||
"label": "LoRA modules (optional)",
|
||||
"info": "The name(s) of target modules to apply LoRA. Use commas to separate multiple modules."
|
||||
"info": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules."
|
||||
},
|
||||
"zh": {
|
||||
"label": "LoRA 作用层(非必填)",
|
||||
"info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。"
|
||||
"label": "LoRA 作用模块(非必填)",
|
||||
"info": "应用 LoRA 的目标模块名称。使用英文逗号分隔多个名称。"
|
||||
}
|
||||
},
|
||||
"additional_target": {
|
||||
"en": {
|
||||
"label": "Additional modules (optional)",
|
||||
"info": "Name(s) of modules apart from LoRA layers to be set as trainable. Use commas to separate multiple modules."
|
||||
},
|
||||
"zh": {
|
||||
"label": "附加模块(非必填)",
|
||||
"info": "除 LoRA 层以外的可训练模块名称。使用英文逗号分隔多个名称。"
|
||||
}
|
||||
},
|
||||
"resume_lora_training": {
|
||||
"en": {
|
||||
"label": "Resume LoRA training",
|
||||
"info": "Whether to resume training from the last LoRA weights or create new lora weights."
|
||||
},
|
||||
"zh": {
|
||||
"label": "继续上次的训练",
|
||||
"info": "接着上次的 LoRA 权重训练或创建一个新的 LoRA 权重。"
|
||||
}
|
||||
},
|
||||
"rlhf_tab": {
|
||||
"en": {
|
||||
"label": "RLHF configurations"
|
||||
},
|
||||
"zh": {
|
||||
"label": "RLHF 参数设置"
|
||||
}
|
||||
},
|
||||
"dpo_beta": {
|
||||
"en": {
|
||||
"label": "DPO beta",
|
||||
"info": "Value of the beta parameter in the DPO loss."
|
||||
},
|
||||
"zh": {
|
||||
"label": "DPO beta 参数",
|
||||
"info": "DPO 损失函数中 beta 超参数大小。"
|
||||
}
|
||||
},
|
||||
"reward_model": {
|
||||
"en": {
|
||||
"label": "Reward model",
|
||||
"info": "Checkpoint of the reward model for PPO training. (Needs to refresh checkpoints)"
|
||||
},
|
||||
"zh": {
|
||||
"label": "奖励模型",
|
||||
"info": "PPO 训练中奖励模型的断点路径。(需要刷新断点)"
|
||||
}
|
||||
},
|
||||
"cmd_preview_btn": {
|
||||
"en": {
|
||||
"value": "Preview command"
|
||||
},
|
||||
"zh": {
|
||||
"value": "预览命令"
|
||||
}
|
||||
},
|
||||
"start_btn": {
|
||||
@@ -389,7 +531,7 @@ LOCALES = {
|
||||
"value": "模型未加载,请先加载模型。"
|
||||
}
|
||||
},
|
||||
"prefix": {
|
||||
"system": {
|
||||
"en": {
|
||||
"placeholder": "System prompt (optional)"
|
||||
},
|
||||
@@ -453,7 +595,7 @@ LOCALES = {
|
||||
"label": "温度系数"
|
||||
}
|
||||
},
|
||||
"save_dir": {
|
||||
"export_dir": {
|
||||
"en": {
|
||||
"label": "Export dir",
|
||||
"info": "Directory to save exported model."
|
||||
@@ -509,10 +651,14 @@ ALERTS = {
|
||||
"en": "Please select a checkpoint.",
|
||||
"zh": "请选择断点。"
|
||||
},
|
||||
"err_no_save_dir": {
|
||||
"err_no_export_dir": {
|
||||
"en": "Please provide export dir.",
|
||||
"zh": "请填写导出目录"
|
||||
},
|
||||
"err_failed": {
|
||||
"en": "Failed.",
|
||||
"zh": "训练出错。"
|
||||
},
|
||||
"info_aborting": {
|
||||
"en": "Aborted, wait for terminating...",
|
||||
"zh": "训练中断,正在等待线程结束……"
|
||||
|
||||
@@ -1,35 +1,35 @@
|
||||
import gradio as gr
|
||||
from gradio.components import Component
|
||||
from typing import Any, Dict, List
|
||||
from typing import TYPE_CHECKING, Dict, List, Set
|
||||
|
||||
from llmtuner.webui.common import get_model_path, list_dataset, load_config
|
||||
from llmtuner.webui.locales import LOCALES
|
||||
from llmtuner.webui.utils import get_time
|
||||
if TYPE_CHECKING:
|
||||
from gradio.components import Component
|
||||
|
||||
|
||||
class Manager:
|
||||
|
||||
def __init__(self, elem_list: List[Dict[str, Component]]):
|
||||
self.elem_list = elem_list
|
||||
def __init__(self) -> None:
|
||||
self.all_elems: Dict[str, Dict[str, "Component"]] = {}
|
||||
|
||||
def gen_refresh(self) -> Dict[str, Any]:
|
||||
refresh_dict = {
|
||||
"dataset": {"choices": list_dataset()["choices"]},
|
||||
"output_dir": {"value": get_time()}
|
||||
def get_elem_by_name(self, name: str) -> "Component":
|
||||
r"""
|
||||
Example: top.lang, train.dataset
|
||||
"""
|
||||
tab_name, elem_name = name.split(".")
|
||||
return self.all_elems[tab_name][elem_name]
|
||||
|
||||
def get_base_elems(self) -> Set["Component"]:
|
||||
return {
|
||||
self.all_elems["top"]["lang"],
|
||||
self.all_elems["top"]["model_name"],
|
||||
self.all_elems["top"]["model_path"],
|
||||
self.all_elems["top"]["checkpoints"],
|
||||
self.all_elems["top"]["finetuning_type"],
|
||||
self.all_elems["top"]["quantization_bit"],
|
||||
self.all_elems["top"]["template"],
|
||||
self.all_elems["top"]["system_prompt"],
|
||||
self.all_elems["top"]["flash_attn"],
|
||||
self.all_elems["top"]["shift_attn"],
|
||||
self.all_elems["top"]["rope_scaling"]
|
||||
}
|
||||
user_config = load_config()
|
||||
if user_config["last_model"]:
|
||||
refresh_dict["model_name"] = {"value": user_config["last_model"]}
|
||||
refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}
|
||||
|
||||
return refresh_dict
|
||||
|
||||
def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]: # cannot use TYPE_CHECKING
|
||||
update_dict = {}
|
||||
refresh_dict = self.gen_refresh()
|
||||
|
||||
for elems in self.elem_list:
|
||||
for name, component in elems.items():
|
||||
update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {}))
|
||||
|
||||
return update_dict
|
||||
def list_elems(self) -> List["Component"]:
|
||||
return [elem for elems in self.all_elems.values() for elem in elems.values()]
|
||||
|
||||
@@ -1,238 +1,254 @@
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import logging
|
||||
import gradio as gr
|
||||
from threading import Thread
|
||||
from gradio.components import Component # cannot use TYPE_CHECKING here
|
||||
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
|
||||
|
||||
import transformers
|
||||
from typing import Generator, List, Optional, Tuple
|
||||
from transformers.trainer import TRAINING_ARGS_NAME
|
||||
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
from llmtuner.extras.constants import DEFAULT_MODULE
|
||||
from llmtuner.extras.constants import TRAINING_STAGES
|
||||
from llmtuner.extras.logging import LoggerHandler
|
||||
from llmtuner.extras.misc import torch_gc
|
||||
from llmtuner.tuner import get_train_args, run_sft
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir
|
||||
from llmtuner.tuner import run_exp
|
||||
from llmtuner.webui.common import get_module, get_save_dir, load_config
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
from llmtuner.webui.utils import format_info, get_eval_results
|
||||
from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.webui.manager import Manager
|
||||
|
||||
|
||||
class Runner:
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, manager: "Manager") -> None:
|
||||
self.manager = manager
|
||||
""" Resume """
|
||||
self.thread: "Thread" = None
|
||||
self.do_train = True
|
||||
self.running_data: Dict["Component", Any] = None
|
||||
self.monitor_inputs: Dict[str, str] = None
|
||||
""" State """
|
||||
self.aborted = False
|
||||
self.running = False
|
||||
""" Handler """
|
||||
self.logger_handler = LoggerHandler()
|
||||
self.logger_handler.setLevel(logging.INFO)
|
||||
logging.root.addHandler(self.logger_handler)
|
||||
transformers.logging.add_handler(self.logger_handler)
|
||||
|
||||
def set_abort(self):
|
||||
@property
|
||||
def alive(self) -> bool:
|
||||
return self.thread is not None
|
||||
|
||||
def set_abort(self) -> None:
|
||||
self.aborted = True
|
||||
self.running = False
|
||||
|
||||
def initialize(
|
||||
self, lang: str, model_name: str, dataset: List[str]
|
||||
) -> Tuple[str, str, LoggerHandler, LogCallback]:
|
||||
def _initialize(self, data: Dict[Component, Any], do_train: bool) -> str:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
lang, model_name, model_path = get("top.lang"), get("top.model_name"), get("top.model_path")
|
||||
dataset = get("train.dataset") if do_train else get("eval.dataset")
|
||||
|
||||
if self.running:
|
||||
return None, ALERTS["err_conflict"][lang], None, None
|
||||
return ALERTS["err_conflict"][lang]
|
||||
|
||||
if not model_name:
|
||||
return None, ALERTS["err_no_model"][lang], None, None
|
||||
return ALERTS["err_no_model"][lang]
|
||||
|
||||
model_name_or_path = get_model_path(model_name)
|
||||
if not model_name_or_path:
|
||||
return None, ALERTS["err_no_path"][lang], None, None
|
||||
if not model_path:
|
||||
return ALERTS["err_no_path"][lang]
|
||||
|
||||
if len(dataset) == 0:
|
||||
return None, ALERTS["err_no_dataset"][lang], None, None
|
||||
return ALERTS["err_no_dataset"][lang]
|
||||
|
||||
self.aborted = False
|
||||
self.running = True
|
||||
self.logger_handler.reset()
|
||||
self.trainer_callback = LogCallback(self)
|
||||
return ""
|
||||
|
||||
logger_handler = LoggerHandler()
|
||||
logger_handler.setLevel(logging.INFO)
|
||||
logging.root.addHandler(logger_handler)
|
||||
transformers.logging.add_handler(logger_handler)
|
||||
trainer_callback = LogCallback(self)
|
||||
|
||||
return model_name_or_path, "", logger_handler, trainer_callback
|
||||
|
||||
def finalize(
|
||||
self, lang: str, finish_info: Optional[str] = None
|
||||
) -> str:
|
||||
def _finalize(self, lang: str, finish_info: str) -> str:
|
||||
self.thread = None
|
||||
self.running = False
|
||||
torch_gc()
|
||||
if self.aborted:
|
||||
return ALERTS["info_aborted"][lang]
|
||||
else:
|
||||
return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
|
||||
return finish_info
|
||||
|
||||
def run_train(
|
||||
self,
|
||||
lang: str,
|
||||
model_name: str,
|
||||
checkpoints: List[str],
|
||||
finetuning_type: str,
|
||||
quantization_bit: str,
|
||||
template: str,
|
||||
source_prefix: str,
|
||||
dataset_dir: str,
|
||||
dataset: List[str],
|
||||
max_source_length: int,
|
||||
max_target_length: int,
|
||||
learning_rate: str,
|
||||
num_train_epochs: str,
|
||||
max_samples: str,
|
||||
batch_size: int,
|
||||
gradient_accumulation_steps: int,
|
||||
lr_scheduler_type: str,
|
||||
max_grad_norm: str,
|
||||
dev_ratio: float,
|
||||
logging_steps: int,
|
||||
save_steps: int,
|
||||
warmup_steps: int,
|
||||
compute_type: str,
|
||||
lora_rank: int,
|
||||
lora_dropout: float,
|
||||
lora_target: str,
|
||||
output_dir: str
|
||||
) -> Generator[str, None, None]:
|
||||
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
|
||||
if error:
|
||||
yield error
|
||||
return
|
||||
def _parse_train_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
user_config = load_config()
|
||||
|
||||
if checkpoints:
|
||||
checkpoint_dir = ",".join(
|
||||
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
|
||||
)
|
||||
if get("top.checkpoints"):
|
||||
checkpoint_dir = ",".join([
|
||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
|
||||
])
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
|
||||
args = dict(
|
||||
model_name_or_path=model_name_or_path,
|
||||
stage=TRAINING_STAGES[get("train.training_stage")],
|
||||
model_name_or_path=get("top.model_path"),
|
||||
do_train=True,
|
||||
overwrite_cache=True,
|
||||
cache_dir=user_config.get("cache_dir", None),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit else None,
|
||||
template=template,
|
||||
source_prefix=source_prefix,
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=",".join(dataset),
|
||||
max_source_length=max_source_length,
|
||||
max_target_length=max_target_length,
|
||||
learning_rate=float(learning_rate),
|
||||
num_train_epochs=float(num_train_epochs),
|
||||
max_samples=int(max_samples),
|
||||
per_device_train_batch_size=batch_size,
|
||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||
lr_scheduler_type=lr_scheduler_type,
|
||||
max_grad_norm=float(max_grad_norm),
|
||||
logging_steps=logging_steps,
|
||||
save_steps=save_steps,
|
||||
warmup_steps=warmup_steps,
|
||||
fp16=(compute_type == "fp16"),
|
||||
bf16=(compute_type == "bf16"),
|
||||
lora_rank=lora_rank,
|
||||
lora_dropout=lora_dropout,
|
||||
lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
|
||||
output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
|
||||
finetuning_type=get("top.finetuning_type"),
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
template=get("top.template"),
|
||||
system_prompt=get("top.system_prompt"),
|
||||
flash_attn=get("top.flash_attn"),
|
||||
shift_attn=get("top.shift_attn"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
dataset_dir=get("train.dataset_dir"),
|
||||
dataset=",".join(get("train.dataset")),
|
||||
cutoff_len=get("train.cutoff_len"),
|
||||
learning_rate=float(get("train.learning_rate")),
|
||||
num_train_epochs=float(get("train.num_train_epochs")),
|
||||
max_samples=int(get("train.max_samples")),
|
||||
per_device_train_batch_size=get("train.batch_size"),
|
||||
gradient_accumulation_steps=get("train.gradient_accumulation_steps"),
|
||||
lr_scheduler_type=get("train.lr_scheduler_type"),
|
||||
max_grad_norm=float(get("train.max_grad_norm")),
|
||||
logging_steps=get("train.logging_steps"),
|
||||
save_steps=get("train.save_steps"),
|
||||
warmup_steps=get("train.warmup_steps"),
|
||||
neft_alpha=get("train.neft_alpha"),
|
||||
train_on_prompt=get("train.train_on_prompt"),
|
||||
upcast_layernorm=get("train.upcast_layernorm"),
|
||||
lora_rank=get("train.lora_rank"),
|
||||
lora_dropout=get("train.lora_dropout"),
|
||||
lora_target=get("train.lora_target") or get_module(get("top.model_name")),
|
||||
additional_target=get("train.additional_target") if get("train.additional_target") else None,
|
||||
resume_lora_training=get("train.resume_lora_training"),
|
||||
output_dir=get_save_dir(get("top.model_name"), get("top.finetuning_type"), get("train.output_dir"))
|
||||
)
|
||||
args[get("train.compute_type")] = True
|
||||
args["disable_tqdm"] = True
|
||||
|
||||
if dev_ratio > 1e-6:
|
||||
args["dev_ratio"] = dev_ratio
|
||||
if TRAINING_STAGES[get("train.training_stage")] in ["rm", "ppo", "dpo"]:
|
||||
args["resume_lora_training"] = (args["quantization_bit"] is not None)
|
||||
|
||||
if args["quantization_bit"] is not None:
|
||||
args["upcast_layernorm"] = True
|
||||
|
||||
if args["stage"] == "ppo":
|
||||
args["reward_model"] = get("train.reward_model")
|
||||
|
||||
if args["stage"] == "dpo":
|
||||
args["dpo_beta"] = get("train.dpo_beta")
|
||||
|
||||
if get("train.val_size") > 1e-6 and args["stage"] != "ppo":
|
||||
args["val_size"] = get("train.val_size")
|
||||
args["evaluation_strategy"] = "steps"
|
||||
args["eval_steps"] = save_steps
|
||||
args["eval_steps"] = get("train.save_steps")
|
||||
args["load_best_model_at_end"] = True
|
||||
|
||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
|
||||
return args
|
||||
|
||||
run_args = dict(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
callbacks=[trainer_callback]
|
||||
)
|
||||
thread = threading.Thread(target=run_sft, kwargs=run_args)
|
||||
thread.start()
|
||||
def _parse_eval_args(self, data: Dict[Component, Any]) -> Dict[str, Any]:
|
||||
get = lambda name: data[self.manager.get_elem_by_name(name)]
|
||||
user_config = load_config()
|
||||
|
||||
while thread.is_alive():
|
||||
time.sleep(1)
|
||||
if self.aborted:
|
||||
yield ALERTS["info_aborting"][lang]
|
||||
else:
|
||||
yield format_info(logger_handler.log, trainer_callback.tracker)
|
||||
|
||||
yield self.finalize(lang)
|
||||
|
||||
def run_eval(
|
||||
self,
|
||||
lang: str,
|
||||
model_name: str,
|
||||
checkpoints: List[str],
|
||||
finetuning_type: str,
|
||||
quantization_bit: str,
|
||||
template: str,
|
||||
source_prefix: str,
|
||||
dataset_dir: str,
|
||||
dataset: List[str],
|
||||
max_source_length: int,
|
||||
max_target_length: int,
|
||||
max_samples: str,
|
||||
batch_size: int,
|
||||
predict: bool
|
||||
) -> Generator[str, None, None]:
|
||||
model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
|
||||
if error:
|
||||
yield error
|
||||
return
|
||||
|
||||
if checkpoints:
|
||||
checkpoint_dir = ",".join(
|
||||
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
|
||||
if get("top.checkpoints"):
|
||||
checkpoint_dir = ",".join([
|
||||
get_save_dir(get("top.model_name"), get("top.finetuning_type"), ckpt) for ckpt in get("top.checkpoints")
|
||||
])
|
||||
output_dir = get_save_dir(
|
||||
get("top.model_name"), get("top.finetuning_type"), "eval_" + "_".join(get("top.checkpoints"))
|
||||
)
|
||||
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_" + "_".join(checkpoints))
|
||||
else:
|
||||
checkpoint_dir = None
|
||||
output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")
|
||||
output_dir = get_save_dir(get("top.model_name"), get("top.finetuning_type"), "eval_base")
|
||||
|
||||
args = dict(
|
||||
model_name_or_path=model_name_or_path,
|
||||
stage="sft",
|
||||
model_name_or_path=get("top.model_path"),
|
||||
do_eval=True,
|
||||
overwrite_cache=True,
|
||||
predict_with_generate=True,
|
||||
cache_dir=user_config.get("cache_dir", None),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type,
|
||||
quantization_bit=int(quantization_bit) if quantization_bit else None,
|
||||
template=template,
|
||||
source_prefix=source_prefix,
|
||||
dataset_dir=dataset_dir,
|
||||
dataset=",".join(dataset),
|
||||
max_source_length=max_source_length,
|
||||
max_target_length=max_target_length,
|
||||
max_samples=int(max_samples),
|
||||
per_device_eval_batch_size=batch_size,
|
||||
finetuning_type=get("top.finetuning_type"),
|
||||
quantization_bit=int(get("top.quantization_bit")) if get("top.quantization_bit") in ["8", "4"] else None,
|
||||
template=get("top.template"),
|
||||
system_prompt=get("top.system_prompt"),
|
||||
flash_attn=get("top.flash_attn"),
|
||||
shift_attn=get("top.shift_attn"),
|
||||
rope_scaling=get("top.rope_scaling") if get("top.rope_scaling") in ["linear", "dynamic"] else None,
|
||||
dataset_dir=get("eval.dataset_dir"),
|
||||
dataset=",".join(get("eval.dataset")),
|
||||
cutoff_len=get("eval.cutoff_len"),
|
||||
max_samples=int(get("eval.max_samples")),
|
||||
per_device_eval_batch_size=get("eval.batch_size"),
|
||||
max_new_tokens=get("eval.max_new_tokens"),
|
||||
top_p=get("eval.top_p"),
|
||||
temperature=get("eval.temperature"),
|
||||
output_dir=output_dir
|
||||
)
|
||||
|
||||
if predict:
|
||||
if get("eval.predict"):
|
||||
args.pop("do_eval", None)
|
||||
args["do_predict"] = True
|
||||
|
||||
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
|
||||
return args
|
||||
|
||||
run_args = dict(
|
||||
model_args=model_args,
|
||||
data_args=data_args,
|
||||
training_args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
callbacks=[trainer_callback]
|
||||
)
|
||||
thread = threading.Thread(target=run_sft, kwargs=run_args)
|
||||
thread.start()
|
||||
def _preview(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
error = self._initialize(data, do_train)
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
yield error, gr.update(visible=False)
|
||||
else:
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
yield gen_cmd(args), gr.update(visible=False)
|
||||
|
||||
while thread.is_alive():
|
||||
time.sleep(1)
|
||||
def _launch(self, data: Dict[Component, Any], do_train: bool) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
error = self._initialize(data, do_train)
|
||||
if error:
|
||||
gr.Warning(error)
|
||||
yield error, gr.update(visible=False)
|
||||
else:
|
||||
args = self._parse_train_args(data) if do_train else self._parse_eval_args(data)
|
||||
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
|
||||
self.running = True
|
||||
self.do_train, self.running_data = do_train, data
|
||||
self.monitor_inputs = dict(lang=data[self.manager.get_elem_by_name("top.lang")], output_dir=args["output_dir"])
|
||||
self.thread = Thread(target=run_exp, kwargs=run_kwargs)
|
||||
self.thread.start()
|
||||
yield from self.monitor()
|
||||
|
||||
def preview_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
yield from self._preview(data, do_train=True)
|
||||
|
||||
def preview_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
yield from self._preview(data, do_train=False)
|
||||
|
||||
def run_train(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
yield from self._launch(data, do_train=True)
|
||||
|
||||
def run_eval(self, data: Dict[Component, Any]) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
yield from self._launch(data, do_train=False)
|
||||
|
||||
def monitor(self) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
|
||||
lang, output_dir = self.monitor_inputs["lang"], self.monitor_inputs["output_dir"]
|
||||
while self.thread.is_alive():
|
||||
time.sleep(2)
|
||||
if self.aborted:
|
||||
yield ALERTS["info_aborting"][lang]
|
||||
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
|
||||
else:
|
||||
yield format_info(logger_handler.log, trainer_callback.tracker)
|
||||
yield self.logger_handler.log, update_process_bar(self.trainer_callback)
|
||||
|
||||
yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))
|
||||
if self.do_train:
|
||||
if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
|
||||
finish_info = ALERTS["info_finished"][lang]
|
||||
else:
|
||||
finish_info = ALERTS["err_failed"][lang]
|
||||
else:
|
||||
if os.path.exists(os.path.join(output_dir, "all_results.json")):
|
||||
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
|
||||
else:
|
||||
finish_info = ALERTS["err_failed"][lang]
|
||||
|
||||
yield self._finalize(lang, finish_info), gr.update(visible=False)
|
||||
|
||||
@@ -3,57 +3,53 @@ import json
|
||||
import gradio as gr
|
||||
import matplotlib.figure
|
||||
import matplotlib.pyplot as plt
|
||||
from typing import Any, Dict, Generator, List, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
from datetime import datetime
|
||||
|
||||
from llmtuner.extras.ploting import smooth
|
||||
from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
|
||||
from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
|
||||
from llmtuner.webui.locales import ALERTS
|
||||
from llmtuner.webui.common import get_save_dir
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llmtuner.extras.callbacks import LogCallback
|
||||
|
||||
|
||||
def format_info(log: str, tracker: dict) -> str:
|
||||
info = log
|
||||
if "current_steps" in tracker:
|
||||
info += "Running **{:d}/{:d}**: {} < {}\n".format(
|
||||
tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"]
|
||||
)
|
||||
return info
|
||||
def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
|
||||
if not callback.max_steps:
|
||||
return gr.update(visible=False)
|
||||
|
||||
percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
|
||||
label = "Running {:d}/{:d}: {} < {}".format(
|
||||
callback.cur_steps,
|
||||
callback.max_steps,
|
||||
callback.elapsed_time,
|
||||
callback.remaining_time
|
||||
)
|
||||
return gr.update(label=label, value=percentage, visible=True)
|
||||
|
||||
|
||||
def get_time() -> str:
|
||||
return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
|
||||
|
||||
|
||||
def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
if (
|
||||
len(dataset) > 0
|
||||
and "file_name" in dataset_info[dataset[0]]
|
||||
and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
|
||||
):
|
||||
return gr.update(interactive=True)
|
||||
else:
|
||||
return gr.update(interactive=False)
|
||||
|
||||
|
||||
def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]:
|
||||
with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
|
||||
dataset_info = json.load(f)
|
||||
data_file = dataset_info[dataset[0]]["file_name"]
|
||||
with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return len(data), data[:2], gr.update(visible=True)
|
||||
|
||||
|
||||
def can_quantize(finetuning_type: str) -> Dict[str, Any]:
|
||||
if finetuning_type != "lora":
|
||||
return gr.update(value="", interactive=False)
|
||||
return gr.update(value="None", interactive=False)
|
||||
else:
|
||||
return gr.update(interactive=True)
|
||||
|
||||
|
||||
def gen_cmd(args: Dict[str, Any]) -> str:
|
||||
args.pop("disable_tqdm", None)
|
||||
args["plot_loss"] = args.get("do_train", None)
|
||||
cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python src/train_bash.py "]
|
||||
for k, v in args.items():
|
||||
if v is not None and v != "":
|
||||
cmd_lines.append(" --{} {} ".format(k, str(v)))
|
||||
cmd_text = "\\\n".join(cmd_lines)
|
||||
cmd_text = "```bash\n{}\n```".format(cmd_text)
|
||||
return cmd_text
|
||||
|
||||
|
||||
def get_eval_results(path: os.PathLike) -> str:
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
result = json.dumps(json.load(f), indent=4)
|
||||
@@ -61,9 +57,11 @@ def get_eval_results(path: os.PathLike) -> str:
|
||||
|
||||
|
||||
def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
|
||||
log_file = os.path.join(get_save_dir(base_model), finetuning_type, output_dir, "trainer_log.jsonl")
|
||||
if not base_model:
|
||||
return
|
||||
log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl")
|
||||
if not os.path.isfile(log_file):
|
||||
return None
|
||||
return
|
||||
|
||||
plt.close("all")
|
||||
fig = plt.figure()
|
||||
@@ -85,41 +83,3 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
|
||||
ax.set_xlabel("step")
|
||||
ax.set_ylabel("loss")
|
||||
return fig
|
||||
|
||||
|
||||
def export_model(
|
||||
lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
|
||||
) -> Generator[str, None, None]:
|
||||
if not model_name:
|
||||
yield ALERTS["err_no_model"][lang]
|
||||
return
|
||||
|
||||
model_name_or_path = get_model_path(model_name)
|
||||
if not model_name_or_path:
|
||||
yield ALERTS["err_no_path"][lang]
|
||||
return
|
||||
|
||||
if not checkpoints:
|
||||
yield ALERTS["err_no_checkpoint"][lang]
|
||||
return
|
||||
|
||||
checkpoint_dir = ",".join(
|
||||
[os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
|
||||
)
|
||||
|
||||
if not save_dir:
|
||||
yield ALERTS["err_no_save_dir"][lang]
|
||||
return
|
||||
|
||||
args = dict(
|
||||
model_name_or_path=model_name_or_path,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
finetuning_type=finetuning_type
|
||||
)
|
||||
|
||||
yield ALERTS["info_exporting"][lang]
|
||||
model_args, _, finetuning_args, _ = get_infer_args(args)
|
||||
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB")
|
||||
tokenizer.save_pretrained(save_dir)
|
||||
yield ALERTS["info_exported"][lang]
|
||||
|
||||
@@ -1,17 +1,8 @@
|
||||
from llmtuner.tuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
|
||||
from llmtuner import run_exp
|
||||
|
||||
|
||||
def main():
|
||||
model_args, data_args, training_args, finetuning_args, general_args = get_train_args()
|
||||
|
||||
if general_args.stage == "pt":
|
||||
run_pt(model_args, data_args, training_args, finetuning_args)
|
||||
elif general_args.stage == "sft":
|
||||
run_sft(model_args, data_args, training_args, finetuning_args)
|
||||
elif general_args.stage == "rm":
|
||||
run_rm(model_args, data_args, training_args, finetuning_args)
|
||||
elif general_args.stage == "ppo":
|
||||
run_ppo(model_args, data_args, training_args, finetuning_args)
|
||||
run_exp()
|
||||
|
||||
|
||||
def _mp_fn(index):
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from llmtuner.webui.interface import create_ui
|
||||
from llmtuner import create_ui
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -1,33 +1,8 @@
|
||||
# coding=utf-8
|
||||
# Implements user interface in browser for fine-tuned models.
|
||||
# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
|
||||
|
||||
import gradio as gr
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from llmtuner.tuner import get_infer_args
|
||||
from llmtuner.webui.chat import WebChatModel
|
||||
from llmtuner.webui.components.chatbot import create_chat_box
|
||||
from llmtuner.webui.manager import Manager
|
||||
|
||||
|
||||
require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
|
||||
from llmtuner import create_web_demo
|
||||
|
||||
|
||||
def main():
|
||||
chat_model = WebChatModel(*get_infer_args())
|
||||
|
||||
with gr.Blocks(title="Web Demo") as demo:
|
||||
lang = gr.Dropdown(choices=["en", "zh"], value="en")
|
||||
|
||||
_, _, _, chat_elems = create_chat_box(chat_model, visible=True)
|
||||
|
||||
manager = Manager([{"lang": lang}, chat_elems])
|
||||
|
||||
demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
|
||||
|
||||
lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
|
||||
|
||||
demo = create_web_demo()
|
||||
demo.queue()
|
||||
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
|
||||
|
||||
|
||||
44
tests/cal_flops.py
Normal file
44
tests/cal_flops.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# coding=utf-8
|
||||
# Calculates the flops of pre-trained models.
|
||||
# Usage: python cal_flops.py --model_name_or_path path_to_model --batch_size 1 --seq_length 512
|
||||
# Inspired by: https://www.deepspeed.ai/tutorials/flops-profiler/
|
||||
|
||||
import fire
|
||||
import torch
|
||||
from typing import Optional
|
||||
from deepspeed.accelerator import get_accelerator # type: ignore
|
||||
from deepspeed.profiling.flops_profiler import get_model_profile # type: ignore
|
||||
|
||||
from llmtuner import ChatModel
|
||||
|
||||
|
||||
def calculate(
|
||||
model_name_or_path: str,
|
||||
batch_size: Optional[int] = 1,
|
||||
seq_length: Optional[int] = 256,
|
||||
flash_attn: Optional[bool] = False
|
||||
):
|
||||
with get_accelerator().device(0):
|
||||
chat_model = ChatModel(dict(
|
||||
model_name_or_path=model_name_or_path,
|
||||
template="vanilla",
|
||||
flash_attn=flash_attn
|
||||
))
|
||||
fake_input = torch.ones((batch_size, seq_length), dtype=torch.long, device=chat_model.model.device)
|
||||
input_dict = {
|
||||
"input_ids": fake_input,
|
||||
"labels": fake_input.clone()
|
||||
}
|
||||
flops, macs, params = get_model_profile(
|
||||
chat_model.model,
|
||||
kwargs=input_dict,
|
||||
print_profile=True,
|
||||
detailed=True
|
||||
)
|
||||
print("FLOPs:", flops)
|
||||
print("MACs:", macs)
|
||||
print("Params:", params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(calculate)
|
||||
@@ -1,133 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Evaluates fine-tuned models automatically.
|
||||
# Usage: python evaluate_zh.py --evalset ceval/ceval-exam:law --split dev --output_file result.json
|
||||
# --api_base http://localhost:8000/v1 --task_type choice --n_samples 100
|
||||
# dataset format: question (string), A (string), B (string), C (string), D (string), answer (Literal["A", "B", "C", "D"])
|
||||
|
||||
|
||||
import os
|
||||
import fire
|
||||
import json
|
||||
import openai
|
||||
from tqdm import tqdm
|
||||
from typing import Literal, Optional
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
def format_example_choice(examples):
|
||||
model_inputs = {"query": [], "label": []}
|
||||
task_template = "请从ABCD四个选项中选出正确的选项,仅输出选项序号。\n{question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\n答案:"
|
||||
for i in range(len(examples["id"])):
|
||||
query = task_template.format(
|
||||
question=examples["question"][i],
|
||||
A=examples["A"][i],
|
||||
B=examples["B"][i],
|
||||
C=examples["C"][i],
|
||||
D=examples["D"][i]
|
||||
)
|
||||
label = examples["answer"][i]
|
||||
model_inputs["query"].append(query)
|
||||
model_inputs["label"].append(label)
|
||||
return model_inputs
|
||||
|
||||
|
||||
def format_example_cloze(examples):
|
||||
model_inputs = {"query": [], "label": []}
|
||||
task_template = "请选择正确的答案填空,仅输出正确的选项。\n{question}\n选项:{A}\n{B}\n{C}\n{D}\n答案:"
|
||||
for i in range(len(examples["id"])):
|
||||
query = task_template.format(
|
||||
question=examples["question"][i],
|
||||
A=examples["A"][i],
|
||||
B=examples["B"][i],
|
||||
C=examples["C"][i],
|
||||
D=examples["D"][i]
|
||||
)
|
||||
label = examples[examples["answer"][i]][i]
|
||||
model_inputs["query"].append(query)
|
||||
model_inputs["label"].append(label)
|
||||
return model_inputs
|
||||
|
||||
|
||||
def format_example_openqa(examples):
|
||||
model_inputs = {"query": [], "label": []}
|
||||
task_template = "回答以下问题:{question}\n答案:"
|
||||
for i in range(len(examples["id"])):
|
||||
query = task_template.format(question=examples["question"][i])
|
||||
label = examples[examples["answer"][i]][i]
|
||||
model_inputs["query"].append(query)
|
||||
model_inputs["label"].append(label)
|
||||
return model_inputs
|
||||
|
||||
|
||||
TASK_DICT = {
|
||||
"choice": format_example_choice,
|
||||
"cloze": format_example_cloze,
|
||||
"openqa": format_example_openqa
|
||||
}
|
||||
|
||||
|
||||
EXT2TYPE = {
|
||||
"csv": "csv",
|
||||
"json": "json",
|
||||
"jsonl": "json"
|
||||
}
|
||||
|
||||
|
||||
def evaluate(
|
||||
evalset: str,
|
||||
api_base: str,
|
||||
output_file: str,
|
||||
split: Optional[str] = "val",
|
||||
task_type: Optional[Literal["choice", "cloze", "openqa"]] = "choice",
|
||||
n_samples: Optional[int] = 20
|
||||
):
|
||||
|
||||
openai.api_base = api_base
|
||||
openai.api_key = "none"
|
||||
|
||||
if os.path.isfile(evalset):
|
||||
dataset = load_dataset(EXT2TYPE[evalset.split(".")[-1]], data_files=evalset)["train"]
|
||||
elif ":" in evalset:
|
||||
evalset, subset = evalset.split(":")
|
||||
dataset = load_dataset(evalset, subset, split=split)
|
||||
else:
|
||||
dataset = load_dataset(evalset, split=split)
|
||||
|
||||
n_samples = min(len(dataset), n_samples)
|
||||
|
||||
dataset = dataset.map(TASK_DICT[task_type], batched=True)
|
||||
dataset = dataset.select(range(n_samples))
|
||||
|
||||
n_correct = 0
|
||||
predictions = []
|
||||
for example in tqdm(dataset):
|
||||
query, label = example["query"], example["label"]
|
||||
predict = openai.ChatCompletion.create(
|
||||
model="default",
|
||||
messages=[{"role": "user", "content": query}],
|
||||
temperature=0.01,
|
||||
top_p=0.01,
|
||||
max_new_tokens=20
|
||||
).choices[0].message.content
|
||||
|
||||
if task_type == "choice" and predict[0].lower() == label[0].lower():
|
||||
n_correct += 1
|
||||
if task_type == "cloze" and label in [predict[:len(label)], predict[-len(label):]]:
|
||||
n_correct += 1
|
||||
if task_type == "openqa" and label in predict:
|
||||
n_correct += 1
|
||||
|
||||
predictions.append({
|
||||
"query": query,
|
||||
"label": label,
|
||||
"predict": predict
|
||||
})
|
||||
|
||||
print("Result: {}/{}\nAccuracy: {:.2f}%".format(n_correct, n_samples, n_correct / n_samples * 100))
|
||||
|
||||
with open(output_file, "w", encoding="utf-8") as f:
|
||||
json.dump(predictions, f, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(evaluate)
|
||||
86
tests/llamafy_baichuan2.py
Normal file
86
tests/llamafy_baichuan2.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# coding=utf-8
|
||||
# Converts the Baichuan2-7B model in the same format as LLaMA2-7B.
|
||||
# Usage: python llamafy_baichuan2.py --input_dir input --output_dir output --shard_size 10GB
|
||||
# Inspired by: https://huggingface.co/fireballoon/baichuan-llama-7b/blob/main/convert_baichuan_to_llama.py
|
||||
# Converted model: https://huggingface.co/hiyouga/Baichuan2-7B-Base-LLaMAfied
|
||||
|
||||
import os
|
||||
import fire
|
||||
import json
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from transformers.modeling_utils import shard_checkpoint, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
CONFIG_NAME = "config.json"
|
||||
|
||||
|
||||
def save_weight(
|
||||
input_dir: str,
|
||||
output_dir: str,
|
||||
shard_size: str
|
||||
):
|
||||
baichuan2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||
for filepath in os.listdir(input_dir):
|
||||
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".bin"):
|
||||
shard_weight = torch.load(os.path.join(input_dir, filepath), map_location="cpu")
|
||||
baichuan2_state_dict.update(shard_weight)
|
||||
|
||||
llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||
for key, value in baichuan2_state_dict.items():
|
||||
if "W_pack" in key:
|
||||
proj_size = value.size(0) // 3
|
||||
llama2_state_dict[key.replace("W_pack", "q_proj")] = value[:proj_size, :]
|
||||
llama2_state_dict[key.replace("W_pack", "k_proj")] = value[proj_size:2*proj_size, :]
|
||||
llama2_state_dict[key.replace("W_pack", "v_proj")] = value[2*proj_size:, :]
|
||||
elif "lm_head" in key:
|
||||
llama2_state_dict[key] = torch.nn.functional.normalize(value)
|
||||
else:
|
||||
llama2_state_dict[key] = value
|
||||
|
||||
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=WEIGHTS_NAME)
|
||||
for shard_file, shard in shards.items():
|
||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||
|
||||
if index is None:
|
||||
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
|
||||
else:
|
||||
with open(os.path.join(output_dir, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
|
||||
json.dump(index, f, indent=2, sort_keys=True)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
|
||||
|
||||
def save_config(
|
||||
input_dir: str,
|
||||
output_dir: str
|
||||
):
|
||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
||||
llama2_config_dict: Dict[str, Any] = json.load(f)
|
||||
|
||||
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
|
||||
llama2_config_dict.pop("auto_map", None)
|
||||
llama2_config_dict.pop("tokenizer_class", None)
|
||||
llama2_config_dict["model_type"] = "llama"
|
||||
|
||||
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
|
||||
json.dump(llama2_config_dict, f, indent=2)
|
||||
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
||||
|
||||
|
||||
def llamafy_baichuan2(
|
||||
input_dir: str,
|
||||
output_dir: str,
|
||||
shard_size: str
|
||||
):
|
||||
try:
|
||||
os.makedirs(output_dir, exist_ok=False)
|
||||
except Exception as e:
|
||||
raise print("Output dir already exists", e)
|
||||
|
||||
save_weight(input_dir, output_dir, shard_size)
|
||||
save_config(input_dir, output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(llamafy_baichuan2)
|
||||
135
tests/llamafy_qwen.py
Normal file
135
tests/llamafy_qwen.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# coding=utf-8
|
||||
# Converts the Qwen models in the same format as LLaMA2.
|
||||
# Usage: python llamafy_qwen.py --input_dir input --output_dir output --shard_size 10GB
|
||||
|
||||
import os
|
||||
import fire
|
||||
import json
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from safetensors import safe_open
|
||||
from transformers.modeling_utils import shard_checkpoint, WEIGHTS_NAME, WEIGHTS_INDEX_NAME
|
||||
from transformers.utils import check_min_version
|
||||
from typing import Any, Dict
|
||||
|
||||
try:
|
||||
check_min_version("4.34.0")
|
||||
except:
|
||||
raise ValueError("Please upgrade `transformers` to 4.34.0")
|
||||
|
||||
|
||||
CONFIG_NAME = "config.json"
|
||||
|
||||
|
||||
def save_weight(
|
||||
input_dir: str,
|
||||
output_dir: str,
|
||||
shard_size: str
|
||||
) -> str:
|
||||
qwen_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||
for filepath in os.listdir(input_dir):
|
||||
if os.path.isfile(os.path.join(input_dir, filepath)) and filepath.endswith(".safetensors"):
|
||||
with safe_open(os.path.join(input_dir, filepath), framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
qwen_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
llama2_state_dict: Dict[str, torch.Tensor] = OrderedDict()
|
||||
torch_dtype = None
|
||||
for key, value in qwen_state_dict.items():
|
||||
if torch_dtype is None:
|
||||
torch_dtype = value.dtype
|
||||
if "wte" in key:
|
||||
llama2_state_dict["model.embed_tokens.weight"] = value
|
||||
elif "ln_f" in key:
|
||||
llama2_state_dict["model.norm.weight"] = value
|
||||
else:
|
||||
key = key.replace("transformer.h", "model.layers")
|
||||
if "attn.c_attn" in key:
|
||||
proj_size = value.size(0) // 3
|
||||
llama2_state_dict[key.replace("attn.c_attn", "self_attn.q_proj")] = value[:proj_size, ...]
|
||||
llama2_state_dict[key.replace("attn.c_attn", "self_attn.k_proj")] = value[proj_size:2*proj_size, ...]
|
||||
llama2_state_dict[key.replace("attn.c_attn", "self_attn.v_proj")] = value[2*proj_size:, ...]
|
||||
elif "attn.c_proj" in key:
|
||||
llama2_state_dict[key.replace("attn.c_proj", "self_attn.o_proj")] = value
|
||||
llama2_state_dict[key.replace("attn.c_proj.weight", "self_attn.o_proj.bias")] = (
|
||||
torch.zeros_like(value[:, 0]).squeeze()
|
||||
)
|
||||
elif "ln_1" in key:
|
||||
llama2_state_dict[key.replace("ln_1", "input_layernorm")] = value
|
||||
elif "ln_2" in key:
|
||||
llama2_state_dict[key.replace("ln_2", "post_attention_layernorm")] = value
|
||||
elif "mlp.w1" in key:
|
||||
llama2_state_dict[key.replace("mlp.w1", "mlp.up_proj")] = value
|
||||
elif "mlp.w2" in key:
|
||||
llama2_state_dict[key.replace("mlp.w2", "mlp.gate_proj")] = value
|
||||
elif "mlp.c_proj" in key:
|
||||
llama2_state_dict[key.replace("mlp.c_proj", "mlp.down_proj")] = value
|
||||
elif "lm_head" in key:
|
||||
llama2_state_dict[key] = value
|
||||
else:
|
||||
raise KeyError("Unable to process key {}".format(key))
|
||||
|
||||
shards, index = shard_checkpoint(llama2_state_dict, max_shard_size=shard_size, weights_name=WEIGHTS_NAME)
|
||||
for shard_file, shard in shards.items():
|
||||
torch.save(shard, os.path.join(output_dir, shard_file))
|
||||
|
||||
if index is None:
|
||||
print("Model weights saved in {}".format(os.path.join(output_dir, WEIGHTS_NAME)))
|
||||
else:
|
||||
with open(os.path.join(output_dir, WEIGHTS_INDEX_NAME), "w", encoding="utf-8") as f:
|
||||
json.dump(index, f, indent=2, sort_keys=True)
|
||||
print("Model weights saved in {}".format(output_dir))
|
||||
|
||||
return str(torch_dtype).replace("torch.", "")
|
||||
|
||||
|
||||
def save_config(
|
||||
input_dir: str,
|
||||
output_dir: str,
|
||||
torch_dtype: str
|
||||
):
|
||||
with open(os.path.join(input_dir, CONFIG_NAME), "r", encoding="utf-8") as f:
|
||||
qwen_config_dict: Dict[str, Any] = json.load(f)
|
||||
|
||||
llama2_config_dict: Dict[str, Any] = OrderedDict()
|
||||
llama2_config_dict["architectures"] = ["LlamaForCausalLM"]
|
||||
llama2_config_dict["hidden_act"] = "silu"
|
||||
llama2_config_dict["hidden_size"] = qwen_config_dict["hidden_size"]
|
||||
llama2_config_dict["initializer_range"] = qwen_config_dict["initializer_range"]
|
||||
llama2_config_dict["intermediate_size"] = qwen_config_dict["intermediate_size"] // 2
|
||||
llama2_config_dict["max_position_embeddings"] = qwen_config_dict["max_position_embeddings"]
|
||||
llama2_config_dict["model_type"] = "llama"
|
||||
llama2_config_dict["num_attention_heads"] = qwen_config_dict["num_attention_heads"]
|
||||
llama2_config_dict["num_hidden_layers"] = qwen_config_dict["num_hidden_layers"]
|
||||
llama2_config_dict["num_key_value_heads"] = qwen_config_dict["hidden_size"] // qwen_config_dict["kv_channels"]
|
||||
llama2_config_dict["pretraining_tp"] = 1
|
||||
llama2_config_dict["rms_norm_eps"] = qwen_config_dict["layer_norm_epsilon"]
|
||||
llama2_config_dict["rope_scaling"] = None
|
||||
llama2_config_dict["tie_word_embeddings"] = qwen_config_dict["tie_word_embeddings"]
|
||||
llama2_config_dict["torch_dtype"] = torch_dtype
|
||||
llama2_config_dict["transformers_version"] = "4.34.0"
|
||||
llama2_config_dict["use_cache"] = True
|
||||
llama2_config_dict["vocab_size"] = qwen_config_dict["vocab_size"]
|
||||
llama2_config_dict["attention_bias"] = True
|
||||
|
||||
with open(os.path.join(output_dir, CONFIG_NAME), "w", encoding="utf-8") as f:
|
||||
json.dump(llama2_config_dict, f, indent=2)
|
||||
print("Model config saved in {}".format(os.path.join(output_dir, CONFIG_NAME)))
|
||||
|
||||
|
||||
def llamafy_qwen(
|
||||
input_dir: str,
|
||||
output_dir: str,
|
||||
shard_size: str
|
||||
):
|
||||
try:
|
||||
os.makedirs(output_dir, exist_ok=False)
|
||||
except Exception as e:
|
||||
raise print("Output dir already exists", e)
|
||||
|
||||
torch_dtype = save_weight(input_dir, output_dir, shard_size)
|
||||
save_config(input_dir, output_dir, torch_dtype)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(llamafy_qwen)
|
||||
@@ -1,743 +0,0 @@
|
||||
# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved.
|
||||
|
||||
import math
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.utils import logging
|
||||
from transformers.generation.utils import GenerationConfig
|
||||
|
||||
from .configuration_baichuan import BaichuanConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# Copied from transformers.models.bloom.modeling_bloom._make_causal_mask
|
||||
def _make_causal_mask(
|
||||
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
|
||||
) -> torch.BoolTensor:
|
||||
"""
|
||||
Make causal mask used for self-attention.
|
||||
"""
|
||||
batch_size, target_length = input_ids_shape
|
||||
mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
|
||||
# ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
|
||||
seq_ids = torch.arange(target_length, device=device)
|
||||
mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask[:, :past_key_values_length] = False
|
||||
|
||||
expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
|
||||
return expanded_mask
|
||||
|
||||
|
||||
# Copied from transformers.models.bloom.modeling_bloom._expand_mask
|
||||
def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
|
||||
"""
|
||||
Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
|
||||
"""
|
||||
batch_size, src_length = mask.shape
|
||||
tgt_length = tgt_length if tgt_length is not None else src_length
|
||||
|
||||
expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
|
||||
return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
|
||||
|
||||
|
||||
# Copied from transformers.models.bloom.modeling_bloom.build_alibi_tensor
|
||||
def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
|
||||
"""
|
||||
Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
|
||||
relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
|
||||
`softmax(l+a) = softmax(l)`.
|
||||
|
||||
Args:
|
||||
Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
|
||||
attention_mask (`torch.Tensor`):
|
||||
Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
|
||||
num_heads (`int`, *required*):
|
||||
number of heads
|
||||
dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
|
||||
dtype of the output tensor
|
||||
"""
|
||||
batch_size, seq_length = attention_mask.shape
|
||||
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
|
||||
base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
|
||||
)
|
||||
powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
|
||||
slopes = torch.pow(base, powers)
|
||||
|
||||
if closest_power_of_2 != num_heads:
|
||||
extra_base = torch.tensor(
|
||||
2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
|
||||
)
|
||||
num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
|
||||
extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
|
||||
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
|
||||
|
||||
# Note: alibi will added to the attention bias that will be applied to the query, key product of attention
|
||||
# => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
|
||||
# => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
|
||||
# => the query_length dimension will then be broadcasted correctly
|
||||
arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
|
||||
alibi = slopes[..., None] * arange_tensor
|
||||
return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, epsilon=1e-6):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.epsilon = epsilon
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
|
||||
|
||||
return (self.weight * hidden_states).to(input_dtype)
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
||||
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
||||
self.act_fn = ACT2FN[hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
class BaichuanAttention(nn.Module):
|
||||
|
||||
def __init__(self, config: BaichuanConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.hidden_size // self.num_heads
|
||||
self.max_position_embeddings = config.model_max_length
|
||||
|
||||
if (self.head_dim * self.num_heads) != self.hidden_size:
|
||||
raise ValueError(
|
||||
f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
|
||||
)
|
||||
|
||||
# Layer-wise attention scaling
|
||||
self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
|
||||
self.beta = 1.0
|
||||
|
||||
self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
|
||||
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
||||
|
||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
proj = self.W_pack(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
|
||||
proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
|
||||
query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
|
||||
query_states = query_states.transpose(1, 2).reshape(bsz * self.num_heads, q_len, self.head_dim)
|
||||
key_states = key_states.permute(0, 2, 3, 1).reshape(bsz * self.num_heads, self.head_dim, q_len)
|
||||
value_states = value_states.transpose(1, 2).reshape(bsz * self.num_heads, q_len, self.head_dim)
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
past_key, past_value = past_key_value
|
||||
key_states = torch.cat([past_key, key_states], dim=2)
|
||||
value_states = torch.cat([past_value, value_states], dim=1)
|
||||
|
||||
_, _, kv_seq_len = key_states.shape
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# [batch_size * num_heads, q_length, kv_length]
|
||||
# we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
|
||||
matmul_result = alibi.baddbmm(
|
||||
batch1=query_states,
|
||||
batch2=key_states,
|
||||
beta=self.beta,
|
||||
alpha=self.inv_norm_factor,
|
||||
)
|
||||
|
||||
# change view to [batch_size, num_heads, q_length, kv_length]
|
||||
attention_scores = matmul_result.view(bsz, self.num_heads, q_len, kv_seq_len)
|
||||
|
||||
# cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
|
||||
# [batch_size, num_heads, q_length, kv_length]
|
||||
input_dtype = attention_scores.dtype
|
||||
# `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
|
||||
if input_dtype == torch.float16:
|
||||
attention_scores = attention_scores.to(torch.float)
|
||||
attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
|
||||
attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
|
||||
|
||||
# change view [batch_size x num_heads, q_length, kv_length]
|
||||
attention_probs_reshaped = attention_probs.view(bsz * self.num_heads, q_len, kv_seq_len)
|
||||
|
||||
# matmul: [batch_size * num_heads, q_length, head_dim]
|
||||
attn_output = torch.bmm(attention_probs_reshaped, value_states)
|
||||
|
||||
attn_output = attn_output.view(bsz, self.num_heads, q_len, self.head_dim)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attention_probs = None
|
||||
|
||||
return attn_output, attention_probs, past_key_value
|
||||
|
||||
|
||||
class BaichuanLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: BaichuanConfig):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.self_attn = BaichuanAttention(config=config)
|
||||
self.mlp = MLP(
|
||||
hidden_size=self.hidden_size,
|
||||
intermediate_size=config.intermediate_size,
|
||||
hidden_act=config.hidden_act,
|
||||
)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
alibi=alibi,
|
||||
attention_mask=attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class BaichuanPreTrainedModel(PreTrainedModel):
|
||||
config_class = BaichuanConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["BaichuanLayer"]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = self.config.initializer_range
|
||||
if isinstance(module, nn.Linear):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, BaichuanModel):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_standard_cache(
|
||||
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
|
||||
num_heads, ...]))
|
||||
"""
|
||||
batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
|
||||
num_heads = batch_size_times_num_heads // batch_size
|
||||
# key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
|
||||
# value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
|
||||
return tuple(
|
||||
(
|
||||
layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
|
||||
layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
|
||||
)
|
||||
for layer_past in past_key_value
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_baichuan_cache(
|
||||
past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Converts the cache to the format expected by Baichuan, i.e. to tuple(tuple([batch_size * num_heads, ...]))
|
||||
"""
|
||||
batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
|
||||
batch_size_times_num_heads = batch_size * num_heads
|
||||
# key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
|
||||
# value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
|
||||
return tuple(
|
||||
(
|
||||
layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
|
||||
layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
|
||||
)
|
||||
for layer_past in past_key_value
|
||||
)
|
||||
|
||||
|
||||
class BaichuanModel(BaichuanPreTrainedModel):
|
||||
|
||||
def __init__(self, config: BaichuanConfig):
|
||||
super().__init__(config)
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
self.n_head = config.num_attention_heads
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
||||
self.layers = nn.ModuleList([BaichuanLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
|
||||
|
||||
self.gradient_checkpointing = config.gradient_checkpointing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embed_tokens = value
|
||||
|
||||
def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
|
||||
return build_alibi_tensor(attention_mask, num_heads, dtype)
|
||||
|
||||
def _prepare_attn_mask(
|
||||
self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
|
||||
) -> torch.BoolTensor:
|
||||
# create causal mask
|
||||
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
|
||||
combined_attention_mask = None
|
||||
device = attention_mask.device
|
||||
_, src_length = input_shape
|
||||
|
||||
if src_length > 1:
|
||||
combined_attention_mask = _make_causal_mask(
|
||||
input_shape, device=device, past_key_values_length=past_key_values_length
|
||||
)
|
||||
|
||||
# [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
|
||||
expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
|
||||
combined_attention_mask = (
|
||||
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
|
||||
)
|
||||
|
||||
return combined_attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot provide both input_ids and inputs_embeds simultaneously")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You need to provide input_ids or inputs_embeds")
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[1]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
||||
else:
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# Compute alibi tensor: check build_alibi_tensor documentation
|
||||
alibi = self.build_alibi_tensor(attention_mask, self.n_head, dtype=hidden_states.dtype)
|
||||
|
||||
causal_mask = self._prepare_attn_mask(
|
||||
attention_mask,
|
||||
input_shape=(batch_size, seq_length),
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
alibi,
|
||||
causal_mask,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
alibi=alibi,
|
||||
attention_mask=causal_mask,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class BaichuanForCausalLM(BaichuanPreTrainedModel):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.model = BaichuanModel(config)
|
||||
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.model.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
**kwargs
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
|
||||
def prepare_inputs_for_generation(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
past_key_values: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
**kwargs
|
||||
) -> dict:
|
||||
if past_key_values:
|
||||
input_ids = input_ids[:, -1:]
|
||||
|
||||
# the cache may be in the standard format (e.g. in contrastive search)
|
||||
if past_key_values[0][0].shape[0] == input_ids.shape[0]:
|
||||
past_key_values = self._convert_to_baichuan_cache(past_key_values)
|
||||
|
||||
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
||||
if inputs_embeds is not None and past_key_values is None:
|
||||
model_inputs = {"inputs_embeds": inputs_embeds}
|
||||
else:
|
||||
model_inputs = {"input_ids": input_ids}
|
||||
|
||||
model_inputs.update(
|
||||
{
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
def _reorder_cache(
|
||||
self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
|
||||
"""
|
||||
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
||||
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
||||
beam_idx at every generation step.
|
||||
|
||||
Output shares the same memory storage as `past`.
|
||||
"""
|
||||
standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
|
||||
|
||||
# Get a copy of `beam_idx` on all the devices where we need those indices.
|
||||
device_to_beam_idx = {
|
||||
past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
|
||||
}
|
||||
reordered_past = tuple(
|
||||
(
|
||||
layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
|
||||
layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
|
||||
)
|
||||
for layer_past in standardized_past
|
||||
)
|
||||
return self._convert_to_baichuan_cache(reordered_past)
|
||||
|
||||
def quantize(self, bits: int):
|
||||
try:
|
||||
from .quantizer import QLinear
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
f"Needs QLinear to run quantize."
|
||||
)
|
||||
|
||||
for layer in self.model.layers:
|
||||
layer.self_attn.W_pack = QLinear(
|
||||
bits=bits,
|
||||
weight=layer.self_attn.W_pack.weight,
|
||||
bias = None,
|
||||
)
|
||||
layer.self_attn.o_proj = QLinear(
|
||||
bits=bits,
|
||||
weight=layer.self_attn.o_proj.weight,
|
||||
bias = None,
|
||||
)
|
||||
layer.mlp.gate_proj = QLinear(
|
||||
bits=bits,
|
||||
weight=layer.mlp.gate_proj.weight,
|
||||
bias = None,
|
||||
)
|
||||
layer.mlp.down_proj = QLinear(
|
||||
bits=bits,
|
||||
weight=layer.mlp.down_proj.weight,
|
||||
bias = None,
|
||||
)
|
||||
layer.mlp.up_proj = QLinear(
|
||||
bits=bits,
|
||||
weight=layer.mlp.up_proj.weight,
|
||||
bias = None,
|
||||
)
|
||||
return self
|
||||
|
||||
def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
|
||||
max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
|
||||
max_input_tokens = self.config.model_max_length - max_new_tokens
|
||||
max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens)
|
||||
total_input, round_input = [], []
|
||||
for i, message in enumerate(messages[::-1]):
|
||||
content_tokens = tokenizer.encode(message['content'])
|
||||
if message['role'] == 'user':
|
||||
round_input = [self.generation_config.user_token_id] + content_tokens + round_input
|
||||
if total_input and len(total_input) + len(round_input) > max_input_tokens:
|
||||
break
|
||||
else:
|
||||
total_input = round_input + total_input
|
||||
if len(total_input) >= max_input_tokens:
|
||||
break
|
||||
else:
|
||||
round_input = []
|
||||
elif message['role'] == 'assistant':
|
||||
round_input = [
|
||||
self.generation_config.assistant_token_id
|
||||
] + content_tokens + [
|
||||
self.generation_config.eos_token_id
|
||||
] + round_input
|
||||
else:
|
||||
raise ValueError(f"message role not supported yet: {message['role']}")
|
||||
total_input = total_input[-max_input_tokens:] # truncate left
|
||||
total_input.append(self.generation_config.assistant_token_id)
|
||||
total_input = torch.LongTensor([total_input]).to(self.device)
|
||||
return total_input
|
||||
|
||||
@torch.no_grad()
|
||||
def chat(self, tokenizer, messages: List[dict], stream=False,
|
||||
generation_config: Optional[GenerationConfig]=None):
|
||||
generation_config = generation_config or self.generation_config
|
||||
input_ids = self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens)
|
||||
if stream:
|
||||
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
|
||||
self.__class__.generate = NewGenerationMixin.generate
|
||||
self.__class__.sample_stream = NewGenerationMixin.sample_stream
|
||||
stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
|
||||
|
||||
def stream_generator():
|
||||
outputs = []
|
||||
for token in self.generate(input_ids, generation_config=stream_config):
|
||||
outputs.append(token.item())
|
||||
yield tokenizer.decode(outputs, skip_special_tokens=True)
|
||||
|
||||
return stream_generator()
|
||||
else:
|
||||
self.__class__.generate = PreTrainedModel.generate # disable stream
|
||||
outputs = self.generate(input_ids, generation_config=generation_config)
|
||||
response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True)
|
||||
return response
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user