96 Commits

Author SHA1 Message Date
hiyouga
6eed1db36c Release v0.1.7
Former-commit-id: 81abe8d6cabaa1ebe74dc32a5dc143389e4c9f31
2023-08-18 17:21:27 +08:00
hiyouga
948124f55e tiny fix
Former-commit-id: 0ee159654ac6339c162745b004e2152ba6fe3c81
2023-08-18 13:07:35 +08:00
hiyouga
2b191ca776 support ppo score norm (trl 0.5.1.dev required)
Former-commit-id: 2b25db6d260ec1532281a592e873579346c7d21c
2023-08-18 12:02:42 +08:00
hiyouga
be4d2822ea fix PPO trainer #551 , update readme
Former-commit-id: faead74849470cebae9e37cde5fab2a71b32aa43
2023-08-18 11:43:10 +08:00
hiyouga
736ddd0319 update readme
Former-commit-id: beaf2fb737dbe64d35334d88b42935c89ef09eee
2023-08-18 01:51:55 +08:00
hiyouga
dfa289aa72 Update .gitignore
Former-commit-id: a1772a4dfef8dfaf7c2c321fad0a70ccf95fe6a0
2023-08-18 01:43:42 +08:00
hiyouga
c2644f939a update training resuming
Former-commit-id: 2ec75c31f609e65116ac3b621eeb7d8ccbf69135
2023-08-18 01:41:17 +08:00
hoshi-hiyouga
f11c1ae562 Merge pull request #434 from niuba/main
add last_checkpoint support

Former-commit-id: b78d461f2826c194c332ead37825704c2cb8b910
2023-08-18 01:38:31 +08:00
hoshi-hiyouga
3126164aa6 Merge branch 'main' into main
Former-commit-id: 870d2c7bf74d0da5a927bef4b8b01d15cc66a3e9
2023-08-18 01:37:23 +08:00
hiyouga
ed10486cad support bf16 ppo #551
Former-commit-id: 092088967de7409a2d51847cfc7afc83a8887320
2023-08-18 00:40:32 +08:00
hiyouga
04fa430c6c fix ChatGLM2 ppo #527 #528
Former-commit-id: 60d6ad64d7c9f6445b0df8de0153c3a311974198
2023-08-18 00:34:59 +08:00
hiyouga
fa1893b59c fix generation bug #532
Former-commit-id: c071121e67374e5f09798db57cfc8668617a36ae
2023-08-17 22:21:34 +08:00
hiyouga
e993e717a5 fix streaming in pt stage #548 #549
Former-commit-id: 050e992bee2a9293cc7399b578de807b5bf9bddc
2023-08-17 17:59:26 +08:00
hiyouga
c80e56423a update readme
Former-commit-id: b74af3c9cf29e1690ae4d5acb27599b1abd152e2
2023-08-17 11:00:22 +08:00
hiyouga
ffa09a01d6 fix baichuan and intern template
Former-commit-id: e1fd18fa6ef1009f978aca5210a259251a0b19a6
2023-08-17 01:27:20 +08:00
hiyouga
7d04f8567b fix generation
Former-commit-id: 66a0300d312ef91c24fcf80667fa3b0bb8e1a342
2023-08-16 22:39:54 +08:00
hiyouga
baa709674f fix system prompt
Former-commit-id: 411e775aa939bdd154a3f1e92921ede90d989f18
2023-08-16 01:35:52 +08:00
hiyouga
ca9a494d0c fix baichuan template #481
Former-commit-id: 7608c6c25877d97ef26a1c209c4073c9c42f4535
2023-08-15 11:38:21 +08:00
hoshi-hiyouga
37eb8c05cc Merge pull request #516 from liuyanyi/add_gitignore
[Enhance] Add .gitignore file

Former-commit-id: 12cfe5482f5ef95d8c386d0af0de381e72eab0f9
2023-08-15 11:25:40 +08:00
hiyouga
7c046edb7b fix ChatGLM RLHF
Former-commit-id: 4e43e887e432ceb7e9287b4e309b63af3c3ba1bf
2023-08-15 11:19:20 +08:00
Yanyi Liu
22cea38b20 Add .gitignore
Former-commit-id: a2ebdeef81706596617da4409fc5da71739bccdc
2023-08-15 11:13:45 +08:00
hiyouga
ef2ca0a827 alert pad_token source
Former-commit-id: f26a84e0d927d2554890daf431a93652e18f4235
2023-08-15 00:07:56 +08:00
hiyouga
7f0b908de2 update webui
Former-commit-id: da30d0fb4abdb825f3383ddd106bb06a84695b7a
2023-08-14 22:45:26 +08:00
hoshi-hiyouga
5fc5e776ff Merge pull request #511 from hiyouga/feature-autoTemplate
add template match and stage in webui

Former-commit-id: 413752ecba845cddaff5fb48db7d3d24b960eec1
2023-08-14 22:44:04 +08:00
codemayq
93b281c016 auto match template when change model_name
Former-commit-id: ab2d7ab0572765ce33a52ac71641062d5d904db4
2023-08-14 20:56:05 +08:00
codemayq
9585699918 add template match and stage in webui
Former-commit-id: d6283e7f041f08f76d18350cb5f6a6c58ca80e92
2023-08-14 20:42:59 +08:00
hiyouga
bceaba551d fix ChatGLM lm_head #494
Former-commit-id: bf0048abdaeb2b9592d38ac991704ad014370b47
2023-08-14 14:14:48 +08:00
hiyouga
0bfeed3a7e fix bug in webui
Former-commit-id: c95f0f687689934379b6c24abf872ffcde06073b
2023-08-14 11:38:42 +08:00
hiyouga
70a780c3c0 fix webui cache
Former-commit-id: 9aba5c197fbc8abaab77f454374f8b497f0310d0
2023-08-14 11:37:01 +08:00
hiyouga
d74ab5306c update readme_zh
Former-commit-id: bdfe7e0285fdeb3a2728669dbdabf70c9652735c
2023-08-14 11:13:25 +08:00
hiyouga
688e8601ab web UI integrating RLHF
Former-commit-id: 137fd146b90f89a1164b56e6d507b30b1f5c2437
2023-08-14 10:48:47 +08:00
hiyouga
4933ab5956 fix #480
Former-commit-id: ec15ca8fffacba2c34e1849c5ce90ca9989d66a2
2023-08-14 00:23:56 +08:00
hiyouga
6c7225a5d4 fix webui
Former-commit-id: 2c8b7414be9b43e20cc1d0575cc4dc1c7545fd86
2023-08-12 23:52:07 +08:00
hiyouga
a22982f2fa tiny fix
Former-commit-id: 50a34c043de6d9e1410291e1d8c1ea9d53754e9e
2023-08-12 22:02:43 +08:00
hiyouga
c95479dddb fix rope scaling
Former-commit-id: 2e0dd36700ec5e8294581c1db4b9431f755fc5f8
2023-08-12 22:00:01 +08:00
hiyouga
fc48bd8da0 update readme
Former-commit-id: 94ac570cb62aa9cd5dba105f0bb4c4da43eca042
2023-08-12 21:29:06 +08:00
hiyouga
d5323bfa3f update readme
Former-commit-id: ecfe87f34b383901f8e97ffb90af459cd55419b1
2023-08-12 21:25:19 +08:00
hiyouga
e9d4a2b507 update readme
Former-commit-id: eadbe9b7a0b6c8897e7a763b519cc5b7e00f3b2c
2023-08-12 21:23:05 +08:00
hiyouga
37bcbe8046 update readme
Former-commit-id: 6fa381400c21fa249cebcdff8c3afd72f8de20b3
2023-08-12 21:00:11 +08:00
hiyouga
fdfb644f0a support rope scaling, fix #475 #476 #478
Former-commit-id: 337d5f68b72230e545e7a94ca789187c7a2b7187
2023-08-12 20:46:27 +08:00
hoshi-hiyouga
cde9f3db57 Merge pull request #479 from hiyouga/feature-addCmdExport
add sft script preview in webui

Former-commit-id: 060225e57d13d8164beb6920410c181fbb28b77a
2023-08-12 20:41:52 +08:00
codemayq
8bf5a98815 add sft script preview in webui
Former-commit-id: 2b72649b404750226aa418b61ef5a6c9ac03938f
2023-08-12 13:53:55 +08:00
hiyouga
be566a15a5 fix unusual output of 8bit models #278 #391
Former-commit-id: 337ce5272b81f5561162beb08814b0e5abf23703
2023-08-12 00:25:29 +08:00
hiyouga
d5f1b99ac4 Release v0.1.6
Former-commit-id: 43c8b3c3c8bfb2e32d17fb3e8b194938e37d54bd
2023-08-11 23:25:57 +08:00
hiyouga
2144bb0e27 Update README_zh.md
Former-commit-id: 4fc154bcf039ba3f9240213158df757881cf3579
2023-08-11 14:06:02 +08:00
hiyouga
bc665bacc7 add defaults
Former-commit-id: 4636d3bbe6b984ca93e3a80ae5239f3ddda461bd
2023-08-11 13:56:26 +08:00
hiyouga
52bfcf4883 fix stop word in baichuan template
Former-commit-id: cba5ac9cfc81f11b97831998ea15def5e0b487c2
2023-08-11 13:51:46 +08:00
hiyouga
06df3d6fb6 fix baichuan template
Former-commit-id: b1681fe35346381cda613297f1cbb710f0a6daa6
2023-08-11 13:45:47 +08:00
hiyouga
ca719a8697 support DPO training (2305.18290)
Former-commit-id: 6d98de148e4af63a7028dfaeb6cf86eb56a4488f
2023-08-11 03:02:53 +08:00
hoshi-hiyouga
72dfd74005 Merge pull request #451 from jovialchen/main
huggingface login for projects must login while running

Former-commit-id: 246ac241277908909b81cdf85fec1f24449dbae9
2023-08-10 17:25:38 +08:00
hiyouga
69302c4420 fix webui val size
Former-commit-id: 490c067d4e0828832e0ebdb704a9207dc974b15b
2023-08-10 15:20:44 +08:00
jiongxuc
42d7019b2e huggingface login for projects must login while running
Former-commit-id: 0a4a2a1d3e0ff1f57215512d294d782080bd383c
2023-08-10 14:57:12 +08:00
hiyouga
5f0d0d6b9b fix template
Former-commit-id: e3967eb1cdd8d19e8afee9ba52e7eb7d6cd86129
2023-08-09 23:14:27 +08:00
hiyouga
76cb63e4f6 fix template
Former-commit-id: 907e8cd86fbd4cdfa26dad21ceaf6e01d8fe37e4
2023-08-09 23:10:20 +08:00
hiyouga
467d571206 support val set in streaming mode
Former-commit-id: faed15b58ed00b1e09bb091e7eee48f5ef7c508b
2023-08-09 23:00:26 +08:00
hiyouga
972bfa700a fix tokenizer
Former-commit-id: 7849587cd4e149291d08edef9a528a1bad796c7e
2023-08-09 17:52:15 +08:00
niuba
458955d0fb add last_checkpoint support
Former-commit-id: 9f1977e4de00b14a9d1b555c25bcaf12998d5046
2023-08-09 16:39:27 +08:00
hiyouga
990eeccf45 fix sft trainer
Former-commit-id: 08cc888b1569572d0cd20bcf3f07e20072a0311a
2023-08-09 16:35:03 +08:00
hiyouga
a3a7465f00 fix rm #420, fix template #426, fix #423
Former-commit-id: 70ea3caaa7a7695c77179cd1bb18707a80a373d7
2023-08-09 16:23:31 +08:00
hoshi-hiyouga
031a819257 fix llama2 template
Former-commit-id: 6c74f726d4e672f5a1a57df201c27c1f697384f0
2023-08-09 00:58:27 +08:00
hoshi-hiyouga
eb4b4e3c8c fix tokenizer
Former-commit-id: fa463ef279b596d5d53cc169831f51b42031fc05
2023-08-09 00:54:54 +08:00
hiyouga
d2e1fe9b1d update webui
Former-commit-id: 343a4cd82b07a40f96ba413d1d991419ff07a24a
2023-08-09 00:26:11 +08:00
hiyouga
6e27a9e39a fix tokenizer #417
Former-commit-id: 01aa678311bfd213a4b410a4e0ff09f48a0d40a1
2023-08-08 23:59:41 +08:00
hiyouga
805478c911 fix bug
Former-commit-id: 0dff1d951f1a9fe05a74d334bf477b55c7c64199
2023-08-08 21:28:28 +08:00
hiyouga
a281cdeb89 fix bug
Former-commit-id: c13ce66021b21e015871b84489eeafa127a424a4
2023-08-08 17:55:55 +08:00
hiyouga
cda698a67f fix chatml template #408
Former-commit-id: 21e0cc3f44c35ae689b00b274391492f413725ac
2023-08-08 17:44:39 +08:00
hiyouga
15acd17716 update args spec
Former-commit-id: a006068346edda6e2851b23d2005fdb218a7287d
2023-08-07 15:23:35 +08:00
hiyouga
34a2bddfcd update readme
Former-commit-id: 06bcbb901f69265632892a5fcbc956b8be1153da
2023-08-07 15:02:02 +08:00
hiyouga
370f817549 Merge branch 'main' of https://github.com/hiyouga/LLaMA-Efficient-Tuning
Former-commit-id: 5c5657227db285048e3850631badb040eea9b6ca
2023-08-07 13:59:16 +08:00
hiyouga
041390c37e fix #376
Former-commit-id: a5b01257ba3323bcb2dd0217fb89a387e39ddbec
2023-08-07 13:58:59 +08:00
hoshi-hiyouga
d9fe4bf500 Merge pull request #382 from hiyouga/feature-updateReadme
add detailed model configs

Former-commit-id: 371c50cf3fd4e3f5e8fb390508c27cb5f18fa531
2023-08-07 13:43:38 +08:00
hiyouga
e0c7e944fc update trainer
Former-commit-id: 0d39b53a5164e34d22fe0a492eaa0d7ac63102fe
2023-08-07 13:34:35 +08:00
codemayq
0845fe67db add detailed model configs
Former-commit-id: 438c43f820e39738eaa1c296aadcf6d141c3289f
2023-08-07 09:30:23 +08:00
hiyouga
fe3b12d900 fix qwen eos token
Former-commit-id: 770830c67886f5872b39b9608949ec62d4616b27
2023-08-06 13:31:17 +08:00
hiyouga
a70d56864e fix qwen tokenizer #361
Former-commit-id: 78a2fa95c8ab669254a6c8fce8138c4395fb0a09
2023-08-05 17:06:05 +08:00
hiyouga
fdbb2c5378 fix template for tiktoken
Former-commit-id: 8328447f81eb5b90310df08cf2928c83ef6355fe
2023-08-05 13:42:42 +08:00
hiyouga
3c0aaf42af remove redundant code
Former-commit-id: dcec1717592107ba9d26eb2ac520309da19d1805
2023-08-05 00:27:27 +08:00
hiyouga
438e19160a fix template
Former-commit-id: b88200a88ea112e043dc44058606805c60e32844
2023-08-05 00:25:00 +08:00
hiyouga
f2b2ff6950 fix llama2 template
Former-commit-id: 08f37145e0bca5f1a8fd7bad01c64dc69b07361b
2023-08-05 00:07:54 +08:00
hoshi-hiyouga
86cef96305 Support safe ChatML template, fix qwen tok #351 #354
https://github.com/openai/openai-python/blob/main/chatml.md
Former-commit-id: 94bfc9d85f7cef3a5eb15085e0124a424373814f
2023-08-05 00:00:23 +08:00
hiyouga
5f50944baf fix bos and eos token
Former-commit-id: ab386f4c0fb5eaac24264a5bbef4c03deeb92158
2023-08-04 23:55:57 +08:00
hiyouga
0804fd2353 fix encode
Former-commit-id: ec382abd906d93cf78c7fbaec753ce6bcf8cfebd
2023-08-04 23:27:55 +08:00
hiyouga
86419eb457 support chatml safe encoding
Former-commit-id: ea52bb135bf9d07738091006ec7ada8df14cf15e
2023-08-04 23:14:28 +08:00
hiyouga
76f3ae7bf3 support interleave probs
Former-commit-id: 168d99816f9bdc746c587f7f09753ba7e0a4b19d
2023-08-04 21:27:35 +08:00
hiyouga
aaa85190eb fix webui export model
Former-commit-id: c34469c05e681239db23e2e666b5ac6a4e38aba9
2023-08-04 14:20:27 +08:00
hiyouga
e2a4e926b9 fix mtloader
Former-commit-id: ca48c2c02c3cfa9afb99971b50daeda9cf14e7cb
2023-08-03 19:29:02 +08:00
hiyouga
d6e922dc1c tiny fix
Former-commit-id: 81ef7017a4c96441951adeff0276cc5ab76a3544
2023-08-03 17:42:28 +08:00
hiyouga
27f4317ec6 fix qwen inference
Former-commit-id: 823f0de0ca0a92b6f48a90e5ffe57a48dc018f1d
2023-08-03 16:31:55 +08:00
hiyouga
e434348216 fix qwen inference
Former-commit-id: 2c5fe45ce1405124f12ecd20e263b5538af97972
2023-08-03 16:15:38 +08:00
hiyouga
2e19afedb8 support Qwen-7B, fix InternLM-7B inference
Former-commit-id: 25d2ca29ecb70cbfd5206333c667042a0c4d2e5a
2023-08-03 15:53:32 +08:00
hiyouga
da08fa7c63 update web demo
Former-commit-id: 5b6ad9adb665096bfb36dc90789a1d4a16345122
2023-08-03 13:28:28 +08:00
hiyouga
9c96b97dc7 fix webui
Former-commit-id: e87630ef77977b2879f1199b9a421acbbbb32a51
2023-08-03 12:43:12 +08:00
hiyouga
28a51b622b modify code structure
Former-commit-id: 6369f9b1751e6f9bb709ba76a85f69cbe0823e5d
2023-08-02 23:17:36 +08:00
hiyouga
8bd1da7144 fix PPO trainer
Former-commit-id: 21982a7d4dd9b7c3a1145b481f02b9990e32dc00
2023-08-02 19:10:23 +08:00
hiyouga
e4d0b8ee6e update ppo trainer
Former-commit-id: c27136a83e167465d3f825e40f10c7b9fcfbf97a
2023-08-02 18:46:41 +08:00
hiyouga
1dfb28b362 fix memory leak of PPO trainer
Former-commit-id: 38410894a5ebf0b043b55a6bd5cca3cd0a44b27d
2023-08-02 17:41:34 +08:00
65 changed files with 2397 additions and 1035 deletions

.gitignore (new file)

@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

README.md

@@ -12,17 +12,25 @@
 ## Changelog
 
-[23/07/31] Now we support dataset streaming. Try `--streaming` and `--max_steps 100` arguments to stream your dataset.
+[23/08/18] Now we support **resuming training**, upgrade `transformers` to `4.31.0` to enjoy this feature.
+
+[23/08/12] Now we support **RoPE scaling** to extend the context length of the LLaMA models. Try `--rope_scaling linear` argument in training and `--rope_scaling dynamic` argument at inference to extrapolate the position embeddings.
+
+[23/08/11] Now we support **[DPO training](https://arxiv.org/abs/2305.18290)** for instruction-tuned models. See [this example](#dpo-training) to train your models (experimental feature).
+
+[23/08/03] Now we support training the **Qwen-7B** model in this repo. Try `--model_name_or_path Qwen/Qwen-7B-Chat` and `--lora_target c_attn` arguments to train the Qwen-7B model. Remember to use `--template chatml` argument when you are using the Qwen-7B-Chat model.
+
+[23/07/31] Now we support **dataset streaming**. Try `--streaming` and `--max_steps 10000` arguments to load your dataset in streaming mode.
 
 [23/07/29] We release two instruction-tuned 13B models at Hugging Face. See these Hugging Face Repos ([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft)) for details.
 
 [23/07/19] Now we support training the **LLaMA-2** models in this repo. Try `--model_name_or_path meta-llama/Llama-2-7b-hf` argument to use the LLaMA-2 model. Remember to use `--template llama2` argument when you are using the LLaMA-2-chat model.
 
-[23/07/18] Now we develop an all-in-one Web UI for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
+[23/07/18] Now we develop an **all-in-one Web UI** for training, evaluation and inference. Try `train_web.py` to fine-tune models in your Web browser. Thank [@KanadeSiina](https://github.com/KanadeSiina) and [@codemayq](https://github.com/codemayq) for their efforts in the development.
 
 [23/07/11] Now we support training the **Baichuan-13B** model in this repo. Try `--model_name_or_path baichuan-inc/Baichuan-13B-Base` and `--lora_target W_pack` arguments to train the Baichuan-13B model. Remember to use `--template baichuan` argument when you are using the Baichuan-13B-Chat model.
 
-[23/07/09] Now we release [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
+[23/07/09] Now we release **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹, an easy-to-use package for editing the factual knowledge of large language models efficiently. Please follow [FastEdit](https://github.com/hiyouga/FastEdit) if you are interested.
 
 [23/07/07] Now we support training the **InternLM-7B** model in this repo. Try `--model_name_or_path internlm/internlm-7b` argument to use the InternLM model. Remember to use `--template intern` argument when you are using the InternLM-chat model.
@@ -40,28 +48,33 @@
 ## Supported Models
 
-- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
-- [LLaMA-2](https://huggingface.co/meta-llama) (7B/13B/70B)
-- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
-- [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
-- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B)
-- [InternLM](https://github.com/InternLM/InternLM) (7B)
+| Model                                                    | Model size                  | Default module  | Template |
+| -------------------------------------------------------- | --------------------------- | --------------- | -------- |
+| [LLaMA](https://github.com/facebookresearch/llama)       | 7B/13B/33B/65B              | q_proj,v_proj   | -        |
+| [LLaMA-2](https://huggingface.co/meta-llama)             | 7B/13B/70B                  | q_proj,v_proj   | llama2   |
+| [BLOOM](https://huggingface.co/bigscience/bloom)         | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | -        |
+| [BLOOMZ](https://huggingface.co/bigscience/bloomz)       | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | -        |
+| [Falcon](https://huggingface.co/tiiuae/falcon-7b)        | 7B/40B                      | query_key_value | -        |
+| [Baichuan](https://github.com/baichuan-inc/baichuan-13B) | 7B/13B                      | W_pack          | baichuan |
+| [InternLM](https://github.com/InternLM/InternLM)         | 7B                          | q_proj,v_proj   | intern   |
+| [Qwen](https://github.com/QwenLM/Qwen-7B)                | 7B                          | c_attn          | chatml   |
+| [XVERSE](https://github.com/xverse-ai/XVERSE-13B)        | 13B                         | q_proj,v_proj   | -        |
+| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B)         | 6B                          | query_key_value | chatglm2 |
+
+- **Default module** is used for the `--lora_target` argument. Please use `python src/train_bash.py -h` to see all available options.
+- For the "base" models, the `--template` argument can be chosen from `default`, `alpaca`, `vicuna` etc. But make sure to use the corresponding template for the "chat" models.
 
 ## Supported Training Approaches
 
-- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
-  - Full-parameter tuning
-  - Partial-parameter tuning
-  - [LoRA](https://arxiv.org/abs/2106.09685)
-  - [QLoRA](https://arxiv.org/abs/2305.14314)
-- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
-  - Full-parameter tuning
-  - Partial-parameter tuning
-  - [LoRA](https://arxiv.org/abs/2106.09685)
-  - [QLoRA](https://arxiv.org/abs/2305.14314)
-- [RLHF](https://arxiv.org/abs/2203.02155)
-  - [LoRA](https://arxiv.org/abs/2106.09685)
-  - [QLoRA](https://arxiv.org/abs/2305.14314)
+| Approach               | Full-parameter     | Partial-parameter  | LoRA               | QLoRA              |
+| ---------------------- | ------------------ | ------------------ | ------------------ | ------------------ |
+| Pre-Training           | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Supervised Fine-Tuning | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| Reward Modeling        |                    |                    | :white_check_mark: | :white_check_mark: |
+| PPO Training           |                    |                    | :white_check_mark: | :white_check_mark: |
+| DPO Training           | :white_check_mark: |                    | :white_check_mark: | :white_check_mark: |
+
+- Use `--quantization_bit 4/8` argument to enable QLoRA.
 
 ## Provided Datasets
@@ -78,7 +91,6 @@
   - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
   - [Self-cognition (zh)](data/self_cognition.json)
   - [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
-  - [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
   - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
   - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
   - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@@ -93,7 +105,7 @@
   - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
   - [UltraChat (en)](https://github.com/thunlp/UltraChat)
   - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
-- For reward modelling:
+- For reward modeling or DPO training:
   - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
   - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
   - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@@ -111,6 +123,7 @@ huggingface-cli login
 - Python 3.8+ and PyTorch 1.13.1+
 - 🤗Transformers, Datasets, Accelerate, PEFT and TRL
+- sentencepiece and tiktoken
 - jieba, rouge-chinese and nltk (used at evaluation)
 - gradio and matplotlib (used in web_demo.py)
 - uvicorn, fastapi and sse-starlette (used in api_demo.py)
@@ -128,7 +141,6 @@ Note: please update `data/dataset_info.json` to use your custom dataset. About t
 ### Dependence Installation (optional)
 
 ```bash
-git lfs install
 git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
 conda create -n llama_etuning python=3.10
 conda activate llama_etuning
@@ -148,18 +160,23 @@ pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/downl
 CUDA_VISIBLE_DEVICES=0 python src/train_web.py
 ```
 
+We strongly recommend using the all-in-one Web UI for newcomers since it can also generate training scripts **automatically**.
+
 Currently the web UI only supports training on **a single GPU**.
 
-### (Continually) Pre-Training
+### Train on a single GPU
+
+#### Pre-Training
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage pt \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
     --do_train \
     --dataset wiki_demo \
     --template default \
     --finetuning_type lora \
+    --lora_target q_proj,v_proj \
     --output_dir path_to_pt_checkpoint \
     --overwrite_cache \
     --per_device_train_batch_size 4 \
@@ -173,16 +190,17 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --fp16
 ```
 
-### Supervised Fine-Tuning
+#### Supervised Fine-Tuning
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage sft \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
    --do_train \
     --dataset alpaca_gpt4_en \
     --template default \
     --finetuning_type lora \
+    --lora_target q_proj,v_proj \
     --output_dir path_to_sft_checkpoint \
     --overwrite_cache \
     --per_device_train_batch_size 4 \
@@ -196,42 +214,42 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --fp16
 ```
 
-Remember to specify `--lora_target W_pack` if you are using Baichuan models.
-
-### Reward Model Training
+#### Reward Modeling
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage rm \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
     --do_train \
     --dataset comparison_gpt4_en \
     --template default \
     --finetuning_type lora \
+    --lora_target q_proj,v_proj \
     --resume_lora_training False \
     --checkpoint_dir path_to_sft_checkpoint \
     --output_dir path_to_rm_checkpoint \
-    --per_device_train_batch_size 4 \
+    --per_device_train_batch_size 2 \
     --gradient_accumulation_steps 4 \
     --lr_scheduler_type cosine \
     --logging_steps 10 \
     --save_steps 1000 \
-    --learning_rate 1e-5 \
+    --learning_rate 1e-6 \
     --num_train_epochs 1.0 \
     --plot_loss \
     --fp16
 ```
 
-### PPO Training (RLHF)
+#### PPO Training
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage ppo \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
     --do_train \
     --dataset alpaca_gpt4_en \
     --template default \
     --finetuning_type lora \
+    --lora_target q_proj,v_proj \
     --resume_lora_training False \
     --checkpoint_dir path_to_sft_checkpoint \
     --reward_model path_to_rm_checkpoint \
@@ -243,17 +261,45 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --save_steps 1000 \
     --learning_rate 1e-5 \
     --num_train_epochs 1.0 \
-    --plot_loss
+    --plot_loss \
+    --fp16
+```
+
+#### DPO Training
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+    --stage dpo \
+    --model_name_or_path path_to_llama_model \
+    --do_train \
+    --dataset comparison_gpt4_en \
+    --template default \
+    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
+    --resume_lora_training False \
+    --checkpoint_dir path_to_sft_checkpoint \
+    --output_dir path_to_dpo_checkpoint \
+    --per_device_train_batch_size 2 \
+    --gradient_accumulation_steps 4 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --save_steps 1000 \
+    --learning_rate 1e-5 \
+    --num_train_epochs 1.0 \
+    --plot_loss \
+    --fp16
 ```
 
 ### Distributed Training
 
+#### Use Huggingface Accelerate
+
 ```bash
 accelerate config # configure the environment
 accelerate launch src/train_bash.py # arguments (same as above)
 ```
 
-<details><summary>Example configuration for full-tuning with DeepSpeed ZeRO-2</summary>
+<details><summary>Example config.yaml for training with DeepSpeed ZeRO-2</summary>
 
 ```yaml
 compute_environment: LOCAL_MACHINE
@@ -281,12 +327,93 @@ use_cpu: false
 </details>
 
+#### Use DeepSpeed
+
+```bash
+deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
+    --deepspeed ds_config.json \
+    ... # arguments (same as above)
+```
+
+<details><summary>Example ds_config.json for training with DeepSpeed ZeRO-2</summary>
+
+```json
+{
+  "train_micro_batch_size_per_gpu": "auto",
+  "gradient_accumulation_steps": "auto",
+  "gradient_clipping": "auto",
+  "zero_allow_untested_optimizer": true,
+  "fp16": {
+    "enabled": "auto",
+    "loss_scale": 0,
+    "initial_scale_power": 16,
+    "loss_scale_window": 1000,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "zero_optimization": {
+    "stage": 2,
+    "allgather_partitions": true,
+    "allgather_bucket_size": 5e8,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 5e8,
+    "overlap_comm": false,
+    "contiguous_gradients": true
+  }
+}
+```
+
+</details>
+
+### Export model
+
+```bash
+python src/export_model.py \
+    --model_name_or_path path_to_llama_model \
+    --template default \
+    --finetuning_type lora \
+    --checkpoint_dir path_to_checkpoint \
+    --output_dir path_to_export
+```
+
+### API Demo
+
+```bash
+python src/api_demo.py \
+    --model_name_or_path path_to_llama_model \
+    --template default \
+    --finetuning_type lora \
+    --checkpoint_dir path_to_checkpoint
+```
+
+Visit `http://localhost:8000/docs` for API documentation.
+
+### CLI Demo
+
+```bash
+python src/cli_demo.py \
+    --model_name_or_path path_to_llama_model \
+    --template default \
+    --finetuning_type lora \
+    --checkpoint_dir path_to_checkpoint
+```
+
+### Web Demo
+
+```bash
+python src/web_demo.py \
+    --model_name_or_path path_to_llama_model \
+    --template default \
+    --finetuning_type lora \
+    --checkpoint_dir path_to_checkpoint
+```
+
 ### Evaluation (BLEU and ROUGE_CHINESE)
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage sft \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
     --do_eval \
     --dataset alpaca_gpt4_en \
     --template default \
@@ -305,7 +432,7 @@ We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage sft \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
     --do_predict \
     --dataset alpaca_gpt4_en \
     --template default \
@@ -317,49 +444,6 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --predict_with_generate
 ```
 
-### API Demo
-
-```bash
-python src/api_demo.py \
-    --model_name_or_path path_to_your_model \
-    --template default \
-    --finetuning_type lora \
-    --checkpoint_dir path_to_checkpoint
-```
-
-Visit `http://localhost:8000/docs` for API documentation.
-
-### CLI Demo
-
-```bash
-python src/cli_demo.py \
-    --model_name_or_path path_to_your_model \
-    --template default \
-    --finetuning_type lora \
-    --checkpoint_dir path_to_checkpoint
-```
-
-### Web Demo
-
-```bash
-python src/web_demo.py \
-    --model_name_or_path path_to_your_model \
-    --template default \
-    --finetuning_type lora \
-    --checkpoint_dir path_to_checkpoint
-```
-
-### Export model
-
-```bash
-python src/export_model.py \
-    --model_name_or_path path_to_your_model \
-    --template default \
-    --finetuning_type lora \
-    --checkpoint_dir path_to_checkpoint \
-    --output_dir path_to_export
-```
-
 ## TODO
 
 - [ ] Supporting flash attention ([torch](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) / [xformers](https://github.com/facebookresearch/xformers) / [flashattn](https://github.com/Dao-AILab/flash-attention)).
@@ -378,6 +462,9 @@ Please follow the model licenses to use the corresponding model weights:
 - [Falcon](LICENSE)
 - [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
 - [InternLM](https://github.com/InternLM/InternLM#open-source-license)
+- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
+- [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
+- [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE)
 
 ## Citation
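The new "Supported Training Approaches" table and its footnote state that `--quantization_bit 4/8` turns any LoRA run into QLoRA. As a minimal sketch (not taken verbatim from the README: the `path_to_qlora_checkpoint` name is made up here, and the trailing arguments are abbreviated the same way the README's own DeepSpeed snippet abbreviates them), the supervised fine-tuning command above would become:

```bash
# QLoRA variant of the SFT example: the only new flag is --quantization_bit,
# which loads the base model quantized to 4 bits before attaching LoRA adapters.
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
    --stage sft \
    --model_name_or_path path_to_llama_model \
    --do_train \
    --dataset alpaca_gpt4_en \
    --template default \
    --finetuning_type lora \
    --lora_target q_proj,v_proj \
    --quantization_bit 4 \
    --output_dir path_to_qlora_checkpoint \
    ... # remaining arguments (same as the SFT example above)
```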

README_zh.md

@@ -12,73 +12,85 @@
 ## 更新日志
 
-[23/07/31] 现在我们支持了训练数据流式加载。请尝试使用 `--streaming` 和 `--max_steps 100` 参数来流式加载数据集。
+[23/08/18] 现在我们支持了**训练状态恢复**,请将 `transformers` 升级至 `4.31.0` 以启用此功能。
+
+[23/08/12] 现在我们支持了 **RoPE 插值**来扩展 LLaMA 模型的上下文长度。请尝试使用 `--rope_scaling linear` 参数训练模型或使用 `--rope_scaling dynamic` 参数评估模型。
+
+[23/08/11] 现在我们支持了指令模型的 **[DPO 训练](https://arxiv.org/abs/2305.18290)**。详情请参阅[此示例](#dpo-训练)(实验性功能)。
+
+[23/08/03] 现在我们支持了 **Qwen-7B** 模型的训练。请尝试使用 `--model_name_or_path Qwen/Qwen-7B-Chat` 和 `--lora_target c_attn` 参数。使用 Qwen-7B-Chat 模型时请添加 `--template chatml` 参数。
+
+[23/07/31] 现在我们支持了**数据流式加载**。请尝试使用 `--streaming` 和 `--max_steps 10000` 参数来流式加载数据集。
 
 [23/07/29] 我们在 Hugging Face 发布了两个 13B 指令微调模型。详细内容请查阅我们的 Hugging Face 项目([LLaMA-2](https://huggingface.co/hiyouga/Llama-2-Chinese-13b-chat) / [Baichuan](https://huggingface.co/hiyouga/baichuan-13b-sft))。
 
-[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。请注意使用 LLaMA-2-chat 模型需要添加 `--template llama2` 参数。
+[23/07/19] 现在我们支持了 **LLaMA-2** 模型的训练。请尝试使用 `--model_name_or_path meta-llama/Llama-2-7b-hf` 参数。使用 LLaMA-2-chat 模型时请添加 `--template llama2` 参数。
 
-[23/07/18] 我们开发了支持训练和测试的浏览器一键微调界面。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
+[23/07/18] 我们开发了支持训练和测试的**浏览器一体化界面**。请尝试使用 `train_web.py` 在您的浏览器中微调模型。感谢 [@KanadeSiina](https://github.com/KanadeSiina) 和 [@codemayq](https://github.com/codemayq) 在该功能开发中付出的努力。
 
-[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path path_to_baichuan_model` 和 `--lora_target W_pack` 参数。请注意使用 Baichuan-13B-Chat 模型需要添加 `--template baichuan` 参数。
+[23/07/11] 现在我们支持了 **Baichuan-13B** 模型的训练。请尝试使用 `--model_name_or_path baichuan-inc/Baichuan-13B-Base` 和 `--lora_target W_pack` 参数。使用 Baichuan-13B-Chat 模型时请添加 `--template baichuan` 参数。
 
-[23/07/09] 我们开源了 [FastEdit](https://github.com/hiyouga/FastEdit)⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。
+[23/07/09] 我们开源了 **[FastEdit](https://github.com/hiyouga/FastEdit)** ⚡🩹,一个简单易用的、能迅速编辑大模型事实记忆的工具包。如果您感兴趣请关注我们的 [FastEdit](https://github.com/hiyouga/FastEdit) 项目。
 
-[23/07/07] 现在我们支持了 **InternLM-7B** 模型的训练。请尝试使用 `--model_name_or_path internlm/internlm-7b` 参数。请注意使用 InternLM-chat 模型需要添加 `--template intern` 参数。
+[23/07/07] 现在我们支持了 **InternLM-7B** 模型的训练。请尝试使用 `--model_name_or_path internlm/internlm-7b` 参数。使用 InternLM-chat 模型时请添加 `--template intern` 参数。
 
 [23/07/05] 现在我们支持了 **Falcon-7B/40B** 模型的训练。请尝试使用 `--model_name_or_path tiiuae/falcon-7b` 和 `--lora_target query_key_value` 参数。
 
 [23/06/29] 我们提供了一个**可复现的**指令模型微调示例,详细内容请查阅 [Hugging Face 项目](https://huggingface.co/hiyouga/baichuan-7b-sft)。
 
-[23/06/22] 我们对齐了[示例 API](src/api_demo.py) 与 [OpenAI API](https://platform.openai.com/docs/api-reference/chat) 的格式,您可以将微调模型接入任意基于 ChatGPT 的应用中。
+[23/06/22] 我们对齐了[示例 API](src/api_demo.py) 与 [OpenAI API](https://platform.openai.com/docs/api-reference/chat) 的格式,您可以将微调模型接入**任意基于 ChatGPT 的应用**中。
 
 [23/06/15] 现在我们支持了 **Baichuan-7B** 模型的训练。请尝试使用 `--model_name_or_path baichuan-inc/Baichuan-7B` 和 `--lora_target W_pack` 参数。
 
-[23/06/03] 现在我们实现了 4 比特的 LoRA 训练(也称 [QLoRA](https://github.com/artidoro/qlora))。请尝试使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
+[23/06/03] 现在我们实现了 4 比特的 LoRA 训练(也称 **[QLoRA](https://github.com/artidoro/qlora)**)。请尝试使用 `--quantization_bit 4` 参数进行 4 比特量化微调。
 
 [23/05/31] 现在我们支持了 **BLOOM & BLOOMZ** 模型的训练。请尝试使用 `--model_name_or_path bigscience/bloomz-7b1-mt` 和 `--lora_target query_key_value` 参数。
 
 ## 模型
 
-- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
-- [LLaMA-2](https://huggingface.co/meta-llama) (7B/13B/70B)
-- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
-- [Falcon](https://huggingface.co/tiiuae/falcon-7b) (7B/40B)
-- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B/13B)
-- [InternLM](https://github.com/InternLM/InternLM) (7B)
+| 模型名                                                   | 模型大小                    | 默认模块        | Template |
+| -------------------------------------------------------- | --------------------------- | --------------- | -------- |
+| [LLaMA](https://github.com/facebookresearch/llama)       | 7B/13B/33B/65B              | q_proj,v_proj   | -        |
+| [LLaMA-2](https://huggingface.co/meta-llama)             | 7B/13B/70B                  | q_proj,v_proj   | llama2   |
+| [BLOOM](https://huggingface.co/bigscience/bloom)         | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | -        |
+| [BLOOMZ](https://huggingface.co/bigscience/bloomz)       | 560M/1.1B/1.7B/3B/7.1B/176B | query_key_value | -        |
+| [Falcon](https://huggingface.co/tiiuae/falcon-7b)        | 7B/40B                      | query_key_value | -        |
+| [Baichuan](https://github.com/baichuan-inc/baichuan-13B) | 7B/13B                      | W_pack          | baichuan |
+| [InternLM](https://github.com/InternLM/InternLM)         | 7B                          | q_proj,v_proj   | intern   |
+| [Qwen](https://github.com/QwenLM/Qwen-7B)                | 7B                          | c_attn          | chatml   |
+| [XVERSE](https://github.com/xverse-ai/XVERSE-13B)        | 13B                         | q_proj,v_proj   | -        |
+| [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B)         | 6B                          | query_key_value | chatglm2 |
+
+- **默认模块**是 `--lora_target` 参数的部分可选项。请使用 `python src/train_bash.py -h` 查看全部可选项。
+- 对于所有“基座”(Base)模型,`--template` 参数可以是 `default`, `alpaca`, `vicuna` 等任意值。但“对话”(Chat)模型请务必使用对应的模板。
 
-## 微调方法
+## 训练方法
 
-- [二次预训练](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
-  - 全参数微调
-  - 部分参数微调
-  - [LoRA](https://arxiv.org/abs/2106.09685)
-  - [QLoRA](https://arxiv.org/abs/2305.14314)
-- [指令监督微调](https://arxiv.org/abs/2109.01652)
-  - 全参数微调
-  - 部分参数微调
-  - [LoRA](https://arxiv.org/abs/2106.09685)
-  - [QLoRA](https://arxiv.org/abs/2305.14314)
-- [人类反馈的强化学习(RLHF)](https://arxiv.org/abs/2203.02155)
-  - [LoRA](https://arxiv.org/abs/2106.09685)
-  - [QLoRA](https://arxiv.org/abs/2305.14314)
+| 方法         | 全参数训练         | 部分参数训练       | LoRA               | QLoRA              |
+| ------------ | ------------------ | ------------------ | ------------------ | ------------------ |
+| 预训练       | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| 指令监督微调 | :white_check_mark: | :white_check_mark: | :white_check_mark: | :white_check_mark: |
+| 奖励模型训练 |                    |                    | :white_check_mark: | :white_check_mark: |
+| PPO 训练     |                    |                    | :white_check_mark: | :white_check_mark: |
+| DPO 训练     | :white_check_mark: |                    | :white_check_mark: | :white_check_mark: |
+
+- 使用 `--quantization_bit 4/8` 参数来启用 QLoRA 训练。
 
 ## 数据集
 
-- 用于二次预训练:
+- 用于预训练:
   - [Wiki Demo (en)](data/wiki_demo.txt)
   - [RefinedWeb (en)](https://huggingface.co/datasets/tiiuae/falcon-refinedweb)
   - [StarCoder (en)](https://huggingface.co/datasets/bigcode/starcoderdata)
   - [Wikipedia (en)](https://huggingface.co/datasets/olm/olm-wikipedia-20221220)
   - [Wikipedia (zh)](https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered)
 - 用于指令监督微调:
   - [Stanford Alpaca (en)](https://github.com/tatsu-lab/stanford_alpaca)
   - [Stanford Alpaca (zh)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
   - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
   - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
   - [Self-cognition (zh)](data/self_cognition.json)
   - [ShareGPT (zh)](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT/tree/main/Chinese-instruction-collection)
-  - [RefGPT (zh)](https://github.com/sufengniu/RefGPT)
   - [Guanaco Dataset (multilingual)](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
   - [BELLE 2M (zh)](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
   - [BELLE 1M (zh)](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
@@ -93,7 +105,7 @@
   - [Web QA (zh)](https://huggingface.co/datasets/suolyer/webqa)
   - [UltraChat (en)](https://github.com/thunlp/UltraChat)
   - [WebNovel (zh)](https://huggingface.co/datasets/zxbsmk/webnovel_cn)
-- 用于奖励模型训练:
+- 用于奖励模型或 DPO 训练:
   - [HH-RLHF (en)](https://huggingface.co/datasets/Anthropic/hh-rlhf)
   - [Open Assistant (multilingual)](https://huggingface.co/datasets/OpenAssistant/oasst1)
   - [GPT-4 Generated Data (en&zh)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
@@ -111,6 +123,7 @@ huggingface-cli login
 - Python 3.8+ 和 PyTorch 1.13.1+
 - 🤗Transformers, Datasets, Accelerate, PEFT 和 TRL
+- sentencepiece 和 tiktoken
 - jieba, rouge-chinese 和 nltk (用于评估)
 - gradio 和 matplotlib (用于网页端交互)
 - uvicorn, fastapi 和 sse-starlette (用于 API)
@@ -128,7 +141,6 @@ huggingface-cli login
 ### 环境搭建(可跳过)
 
 ```bash
-git lfs install
 git clone https://github.com/hiyouga/LLaMA-Efficient-Tuning.git
 conda create -n llama_etuning python=3.10
 conda activate llama_etuning
@@ -142,24 +154,29 @@ pip install -r requirements.txt
 pip install https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl
 ```
 
-### 浏览器一键微调/测试
+### 浏览器一体化界面
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_web.py
 ```
 
+我们极力推荐新手使用浏览器一体化界面,因为它还可以**自动**生成运行所需的命令行脚本。
+
 目前网页 UI 仅支持**单卡训练**。
 
-### 二次预训练
+### 单 GPU 训练
+
+#### 预训练
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage pt \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
     --do_train \
     --dataset wiki_demo \
     --template default \
     --finetuning_type lora \
+    --lora_target q_proj,v_proj \
     --output_dir path_to_pt_checkpoint \
     --overwrite_cache \
     --per_device_train_batch_size 4 \
@@ -173,16 +190,17 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --fp16
 ```
 
-### 指令监督微调
+#### 指令监督微调
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage sft \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
     --do_train \
     --dataset alpaca_gpt4_zh \
     --template default \
     --finetuning_type lora \
+    --lora_target q_proj,v_proj \
     --output_dir path_to_sft_checkpoint \
     --overwrite_cache \
     --per_device_train_batch_size 4 \
@@ -196,42 +214,42 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --fp16
 ```
 
-使用 Baichuan 模型时请指定 `--lora_target W_pack` 参数。
-
-### 奖励模型训练
+#### 奖励模型训练
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage rm \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
     --do_train \
     --dataset comparison_gpt4_zh \
     --template default \
     --finetuning_type lora \
+    --lora_target q_proj,v_proj \
     --resume_lora_training False \
     --checkpoint_dir path_to_sft_checkpoint \
     --output_dir path_to_rm_checkpoint \
-    --per_device_train_batch_size 4 \
+    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 4 \
     --lr_scheduler_type cosine \
     --logging_steps 10 \
     --save_steps 1000 \
-    --learning_rate 1e-5 \
+    --learning_rate 1e-6 \
     --num_train_epochs 1.0 \
     --plot_loss \
     --fp16
 ```
 
-### RLHF 训练
+#### PPO 训练
 
 ```bash
 CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --stage ppo \
-    --model_name_or_path path_to_your_model \
+    --model_name_or_path path_to_llama_model \
     --do_train \
     --dataset alpaca_gpt4_zh \
     --template default \
     --finetuning_type lora \
+    --lora_target q_proj,v_proj \
     --resume_lora_training False \
     --checkpoint_dir path_to_sft_checkpoint \
     --reward_model path_to_rm_checkpoint \
@@ -246,8 +264,35 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
     --plot_loss
 ```
 
+#### DPO 训练
+
+```bash
+CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
+    --stage dpo \
+    --model_name_or_path path_to_llama_model \
+    --do_train \
+    --dataset comparison_gpt4_zh \
+    --template default \
+    --finetuning_type lora \
+    --lora_target q_proj,v_proj \
+    --resume_lora_training False \
+    --checkpoint_dir path_to_sft_checkpoint \
+    --output_dir path_to_dpo_checkpoint \
+    --per_device_train_batch_size 2 \
+    --gradient_accumulation_steps 4 \
+    --lr_scheduler_type cosine \
+    --logging_steps 10 \
+    --save_steps 1000 \
+    --learning_rate 1e-5 \
+    --num_train_epochs 1.0 \
+    --plot_loss \
+    --fp16
+```
+
 ### 多 GPU 分布式训练
 
+#### 使用 Huggingface Accelerate
+
 ```bash
 accelerate config # 首先配置分布式环境
 accelerate launch src/train_bash.py # 参数同上
@@ -281,47 +326,60 @@ use_cpu: false
</details> </details>
### 指标评估BLEU分数和汉语ROUGE分数 #### 使用 DeepSpeed
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ deepspeed --num_gpus 8 --master_port=9901 src/train_bash.py \
--stage sft \ --deepspeed ds_config.json \
--model_name_or_path path_to_your_model \ ... # 参数同上
--do_eval \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_eval_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate
``` ```
我们建议在量化模型的评估中使用 `--per_device_eval_batch_size=1``--max_target_length 128` 参数。 <details><summary>使用 DeepSpeed ZeRO-2 进行全参数微调的 DeepSpeed 配置示例</summary>
### 模型预测 ```json
{
"train_micro_batch_size_per_gpu": "auto",
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"zero_allow_untested_optimizer": true,
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"initial_scale_power": 16,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"zero_optimization": {
"stage": 2,
"allgather_partitions": true,
"allgather_bucket_size": 5e8,
"reduce_scatter": true,
"reduce_bucket_size": 5e8,
"overlap_comm": false,
"contiguous_gradients": true
}
}
```
</details>
### 导出微调后的模型
```bash ```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ python src/export_model.py \
--stage sft \ --model_name_or_path path_to_llama_model \
--model_name_or_path path_to_your_model \
--do_predict \
--dataset alpaca_gpt4_zh \
--template default \ --template default \
--finetuning_type lora \ --finetuning_type lora \
--checkpoint_dir path_to_checkpoint \ --checkpoint_dir path_to_checkpoint \
--output_dir path_to_predict_result \ --output_dir path_to_export
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate
``` ```
### API 服务 ### API 服务
```bash ```bash
python src/api_demo.py \ python src/api_demo.py \
--model_name_or_path path_to_your_model \ --model_name_or_path path_to_llama_model \
--template default \ --template default \
--finetuning_type lora \ --finetuning_type lora \
--checkpoint_dir path_to_checkpoint --checkpoint_dir path_to_checkpoint
@@ -333,7 +391,7 @@ python src/api_demo.py \
```bash ```bash
python src/cli_demo.py \ python src/cli_demo.py \
--model_name_or_path path_to_your_model \ --model_name_or_path path_to_llama_model \
--template default \ --template default \
--finetuning_type lora \ --finetuning_type lora \
--checkpoint_dir path_to_checkpoint --checkpoint_dir path_to_checkpoint
@@ -343,21 +401,46 @@ python src/cli_demo.py \
```bash ```bash
python src/web_demo.py \ python src/web_demo.py \
--model_name_or_path path_to_your_model \ --model_name_or_path path_to_llama_model \
--template default \ --template default \
--finetuning_type lora \ --finetuning_type lora \
--checkpoint_dir path_to_checkpoint --checkpoint_dir path_to_checkpoint
``` ```
### 导出微调模型 ### 指标评估BLEU 分数和汉语 ROUGE 分数)
```bash ```bash
python src/export_model.py \ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--model_name_or_path path_to_your_model \ --stage sft \
--model_name_or_path path_to_llama_model \
--do_eval \
--dataset alpaca_gpt4_zh \
--template default \ --template default \
--finetuning_type lora \ --finetuning_type lora \
--checkpoint_dir path_to_checkpoint \ --checkpoint_dir path_to_checkpoint \
--output_dir path_to_export --output_dir path_to_eval_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate
```
我们建议在量化模型的评估中使用 `--per_device_eval_batch_size=1``--max_target_length 128`
### 模型预测
```bash
CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \
--stage sft \
--model_name_or_path path_to_llama_model \
--do_predict \
--dataset alpaca_gpt4_zh \
--template default \
--finetuning_type lora \
--checkpoint_dir path_to_checkpoint \
--output_dir path_to_predict_result \
--per_device_eval_batch_size 8 \
--max_samples 100 \
--predict_with_generate
``` ```
## TODO ## TODO
@@ -378,6 +461,9 @@ python src/export_model.py \
- [Falcon](LICENSE) - [Falcon](LICENSE)
- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) - [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B/resolve/main/baichuan-7B%20%E6%A8%A1%E5%9E%8B%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf)
- [InternLM](https://github.com/InternLM/InternLM#open-source-license) - [InternLM](https://github.com/InternLM/InternLM#open-source-license)
- [Qwen](https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/LICENSE)
- [XVERSE](https://github.com/xverse-ai/XVERSE-13B/blob/main/MODEL_LICENSE.pdf)
- [ChatGLM2](https://github.com/THUDM/ChatGLM2-6B/blob/main/MODEL_LICENSE)
## Citation

**requirements.txt**

```diff
@@ -3,8 +3,10 @@ transformers>=4.29.1
 datasets>=2.12.0
 accelerate>=0.21.0
 peft>=0.4.0
-trl>=0.4.7
+trl>=0.5.0
+scipy
 sentencepiece
+tiktoken
 jieba
 rouge-chinese
 nltk
```

**src/api_demo.py**

```diff
@@ -1,19 +1,13 @@
-# coding=utf-8
-# Implements API for fine-tuned models in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
-# Usage: python api_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
-# Visit http://localhost:8000/docs for document.
-
 import uvicorn

-from llmtuner import ChatModel
-from llmtuner.api.app import create_app
-from llmtuner.tuner import get_infer_args
+from llmtuner import ChatModel, create_app


 def main():
-    chat_model = ChatModel(*get_infer_args())
+    chat_model = ChatModel()
     app = create_app(chat_model)
     uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
+    print("Visit http://localhost:8000/docs for API document.")


 if __name__ == "__main__":
```

**src/cli_demo.py**

```diff
@@ -1,13 +1,8 @@
-# coding=utf-8
-# Implements stream chat in command line for fine-tuned models.
-# Usage: python cli_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
-
 from llmtuner import ChatModel
-from llmtuner.tuner import get_infer_args


 def main():
-    chat_model = ChatModel(*get_infer_args())
+    chat_model = ChatModel()
     history = []
     print("Welcome to the CLI application, use `clear` to remove the history, use `exit` to exit the application.")
```

**src/export_model.py**

```diff
@@ -1,16 +1,8 @@
-# coding=utf-8
-# Exports the fine-tuned model.
-# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
-
-from llmtuner.tuner import get_train_args, load_model_and_tokenizer
+from llmtuner import export_model


 def main():
-    model_args, _, training_args, finetuning_args, _ = get_train_args()
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-    model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
-    tokenizer.save_pretrained(training_args.output_dir)
-    print("model and tokenizer have been saved at:", training_args.output_dir)
+    export_model()


 if __name__ == "__main__":
```

**src/llmtuner/__init__.py**

```diff
@@ -1,4 +1,9 @@
+# Level: api, webui > chat > tuner > dsets > extras, hparams
+
+from llmtuner.api import create_app
 from llmtuner.chat import ChatModel
+from llmtuner.tuner import export_model, run_exp
+from llmtuner.webui import create_ui, create_web_demo

-__version__ = "0.1.5"
+__version__ = "0.1.7"
```

**src/llmtuner/api/__init__.py**

```diff
@@ -0,0 +1 @@
+from llmtuner.api.app import create_app
```

**src/llmtuner/api/app.py**

```diff
@@ -5,9 +5,8 @@ from contextlib import asynccontextmanager
 from sse_starlette import EventSourceResponse
 from typing import List, Tuple

-from llmtuner.tuner import get_infer_args
 from llmtuner.extras.misc import torch_gc
-from llmtuner.chat.stream_chat import ChatModel
+from llmtuner.chat import ChatModel
 from llmtuner.api.protocol import (
     Role,
     Finish,
@@ -48,15 +47,15 @@ def create_app(chat_model: ChatModel) -> FastAPI:
     @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
     async def create_chat_completion(request: ChatCompletionRequest):
-        if request.messages[-1].role != Role.USER:
+        if len(request.messages) < 1 or request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=400, detail="Invalid request")

         query = request.messages[-1].content
         prev_messages = request.messages[:-1]
         if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
-            prefix = prev_messages.pop(0).content
+            system = prev_messages.pop(0).content
         else:
-            prefix = None
+            system = None

         history = []
         if len(prev_messages) % 2 == 0:
@@ -65,11 +64,11 @@ def create_app(chat_model: ChatModel) -> FastAPI:
                 history.append([prev_messages[i].content, prev_messages[i+1].content])

         if request.stream:
-            generate = predict(query, history, prefix, request)
+            generate = predict(query, history, system, request)
             return EventSourceResponse(generate, media_type="text/event-stream")

         response, (prompt_length, response_length) = chat_model.chat(
-            query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
+            query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
         )

         usage = ChatCompletionResponseUsage(
@@ -86,7 +85,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
         return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)

-    async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
+    async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest):
         choice_data = ChatCompletionResponseStreamChoice(
             index=0,
             delta=DeltaMessage(role=Role.ASSISTANT),
@@ -96,7 +95,7 @@ def create_app(chat_model: ChatModel) -> FastAPI:
             yield chunk.json(exclude_unset=True, ensure_ascii=False)

         for new_text in chat_model.stream_chat(
-            query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
+            query, history, system, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
         ):
             if len(new_text) == 0:
                 continue
@@ -122,6 +121,6 @@ def create_app(chat_model: ChatModel) -> FastAPI:
 if __name__ == "__main__":
-    chat_model = ChatModel(*get_infer_args())
+    chat_model = ChatModel()
     app = create_app(chat_model)
     uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
```

**src/llmtuner/chat/stream_chat.py**

```diff
@@ -1,44 +1,37 @@
 import torch
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple
+from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer

 from llmtuner.extras.misc import dispatch_model, get_logits_processor
-from llmtuner.extras.template import get_template
-from llmtuner.tuner import load_model_and_tokenizer
-
-if TYPE_CHECKING:
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+from llmtuner.extras.template import get_template_and_fix_tokenizer
+from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer


 class ChatModel:

-    def __init__(
-        self,
-        model_args: "ModelArguments",
-        data_args: "DataArguments",
-        finetuning_args: "FinetuningArguments",
-        generating_args: "GeneratingArguments"
-    ) -> None:
+    def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
+        model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
         self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
         self.model = dispatch_model(self.model)
-        self.template = get_template(data_args.template)
-        self.source_prefix = data_args.source_prefix
-        self.generating_args = generating_args
+        self.model = self.model.eval() # enable evaluation mode
+        self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
+        self.system_prompt = data_args.system_prompt

     def process_args(
         self,
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = None,
+        system: Optional[str] = None,
         **input_kwargs
     ) -> Tuple[Dict[str, Any], int]:
-        prefix = prefix or self.source_prefix
-        prompt = self.template.get_prompt(query, history, prefix, self.tokenizer.eos_token)
-        inputs = self.tokenizer([prompt], return_tensors="pt")
-        inputs = inputs.to(self.model.device)
-        prompt_length = len(inputs["input_ids"][0])
+        system = system or self.system_prompt
+        prompt, _ = self.template.encode_oneturn(
+            tokenizer=self.tokenizer, query=query, resp="", history=history, system=system
+        )
+        input_ids = torch.tensor([prompt], device=self.model.device)
+        prompt_length = len(input_ids[0])

         do_sample = input_kwargs.pop("do_sample", None)
         temperature = input_kwargs.pop("temperature", None)
@@ -50,12 +43,14 @@ class ChatModel:
         gen_kwargs = self.generating_args.to_dict()
         gen_kwargs.update(dict(
-            input_ids=inputs["input_ids"],
+            input_ids=input_ids,
             do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
             temperature=temperature or gen_kwargs["temperature"],
             top_p=top_p or gen_kwargs["top_p"],
             top_k=top_k or gen_kwargs["top_k"],
             repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
+            eos_token_id=list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids)),
+            pad_token_id=self.tokenizer.pad_token_id,
             logits_processor=get_logits_processor()
         ))
@@ -74,10 +69,10 @@ class ChatModel:
         self,
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = None,
+        system: Optional[str] = None,
         **input_kwargs
     ) -> Tuple[str, Tuple[int, int]]:
-        gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
+        gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs)
         generation_output = self.model.generate(**gen_kwargs)
         outputs = generation_output.tolist()[0][prompt_length:]
         response = self.tokenizer.decode(outputs, skip_special_tokens=True)
@@ -89,10 +84,10 @@ class ChatModel:
         self,
         query: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = None,
+        system: Optional[str] = None,
         **input_kwargs
     ) -> Generator[str, None, None]:
-        gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
+        gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs)
         streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
         gen_kwargs["streamer"] = streamer
```
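
The refactored `ChatModel` now parses its own arguments, so it can be driven from a plain dict. A usage sketch (the dict keys mirror the CLI flags above and the paths are placeholders):

```python
# Sketch: programmatic use of the new ChatModel entry point.
from llmtuner import ChatModel

chat_model = ChatModel(dict(
    model_name_or_path="path_to_llama_model",  # placeholder paths
    template="default",
    finetuning_type="lora",
    checkpoint_dir="path_to_checkpoint"
))
response, (prompt_length, response_length) = chat_model.chat("Hello!")
print(response)
```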

**src/llmtuner/dsets/loader.py**

```diff
@@ -1,48 +1,25 @@
 import os
-import hashlib
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Union

-from datasets import Value, concatenate_datasets, interleave_datasets, load_dataset
+from datasets import concatenate_datasets, interleave_datasets, load_dataset

+from llmtuner.dsets.utils import checksum, EXT2TYPE
 from llmtuner.extras.logging import get_logger

 if TYPE_CHECKING:
-    from datasets import Dataset
+    from datasets import Dataset, IterableDataset
     from llmtuner.hparams import ModelArguments, DataArguments


 logger = get_logger(__name__)

-EXT2TYPE = {
-    "csv": "csv",
-    "json": "json",
-    "jsonl": "json",
-    "txt": "text"
-}
-
-
-def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
-    if file_sha1 is None:
-        logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
-        return
-
-    if len(data_files) != 1:
-        logger.warning("Checksum failed: too many files.")
-        return
-
-    with open(data_files[0], "rb") as f:
-        sha1 = hashlib.sha1(f.read()).hexdigest()
-        if sha1 != file_sha1:
-            logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
-

 def get_dataset(
     model_args: "ModelArguments",
     data_args: "DataArguments"
-) -> "Dataset":
+) -> Union["Dataset", "IterableDataset"]:
     max_samples = data_args.max_samples
-    all_datasets: List["Dataset"] = [] # support multiple datasets
+    all_datasets: List[Union["Dataset", "IterableDataset"]] = [] # support multiple datasets

     for dataset_attr in data_args.dataset_list:
         logger.info("Loading dataset {}...".format(dataset_attr))
@@ -92,12 +69,11 @@ def get_dataset(
         if getattr(dataset_attr, column_name) and getattr(dataset_attr, column_name) != column_name:
             dataset = dataset.rename_column(getattr(dataset_attr, column_name), column_name)

-        if dataset_attr.source_prefix: # add prefix
-            features = None
+        if dataset_attr.system_prompt: # add system prompt
             if data_args.streaming:
-                features = dataset.features
-                features["prefix"] = Value(dtype="string", id=None)
-            dataset = dataset.map(lambda _: {"prefix": dataset_attr.source_prefix}, features=features)
+                dataset = dataset.map(lambda _: {"system": dataset_attr.system_prompt})
+            else:
+                dataset = dataset.add_column("system", [dataset_attr.system_prompt] * len(dataset))

         all_datasets.append(dataset)
@@ -111,6 +87,6 @@ def get_dataset(
         if not data_args.streaming:
             logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
         stopping_strategy = "first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted"
-        return interleave_datasets(all_datasets, stopping_strategy=stopping_strategy)
+        return interleave_datasets(all_datasets, data_args.interleave_probs, stopping_strategy=stopping_strategy)
     else:
         raise ValueError("Unknown mixing strategy.")
```
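
The last hunk threads `interleave_probs` through to `datasets.interleave_datasets`. A standalone illustration of what that argument does (toy data, not repo code):

```python
# Sketch: probability-weighted dataset mixing, as enabled by the change above.
from datasets import Dataset, interleave_datasets

d1 = Dataset.from_dict({"text": ["a", "b", "c"]})
d2 = Dataset.from_dict({"text": ["x", "y", "z"]})
mixed = interleave_datasets(
    [d1, d2],
    probabilities=[0.8, 0.2],  # sample d1 four times as often as d2
    seed=42,
    stopping_strategy="all_exhausted"
)
print(mixed["text"])
```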

**src/llmtuner/dsets/preprocess.py**

```diff
@@ -1,37 +1,43 @@
-from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal
+import tiktoken
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Union
 from itertools import chain

 from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.template import get_template
+from llmtuner.extras.template import get_template_and_fix_tokenizer

 if TYPE_CHECKING:
-    from datasets import Dataset
+    from datasets import Dataset, IterableDataset
     from transformers import Seq2SeqTrainingArguments
     from transformers.tokenization_utils import PreTrainedTokenizer
     from llmtuner.hparams import DataArguments


 def preprocess_dataset(
-    dataset: "Dataset",
+    dataset: Union["Dataset", "IterableDataset"],
     tokenizer: "PreTrainedTokenizer",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo"]
-) -> "Dataset":
-    column_names = list(dataset.column_names)
-    template = get_template(data_args.template)
+) -> Union["Dataset", "IterableDataset"]:
+    column_names = list(next(iter(dataset)).keys())
+    template = get_template_and_fix_tokenizer(data_args.template, tokenizer)

     def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]:
         for i in range(len(examples["prompt"])):
             query, response = examples["prompt"][i], examples["response"][i]
             query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query
             history = examples["history"][i] if "history" in examples else None
-            prefix = examples["prefix"][i] if "prefix" in examples else None
-            yield query, response, history, prefix
+            system = examples["system"][i] if "system" in examples else None
+            yield query, response, history, system

     def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
-        # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
-        tokenized_examples = tokenizer(examples["prompt"], add_special_tokens=False)
+        # build grouped texts with format `X1 X2 X3 ...` (without <eos>)
+        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
+            kwargs = dict(allowed_special="all")
+        else:
+            kwargs = dict(add_special_tokens=False)
+        tokenized_examples = tokenizer(examples["prompt"], **kwargs)
         concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
         total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
         block_size = data_args.max_source_length
@@ -42,33 +48,28 @@ def preprocess_dataset(
             k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
             for k, t in concatenated_examples.items()
         }
+        result["labels"] = result["input_ids"].copy()
         return result

     def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
         # build inputs with format `<bos> X Y <eos>` and labels with format `<ignore> ... <ignore> Y <eos>`
-        # for input with history, we build multiple input-label pairs just like:
-        # https://github.com/lm-sys/FastChat/blob/f17c092f64840fa6354ed52789dccb2daa793d0b/fastchat/train/train.py#L112
+        # for multiturn examples, we only mask the prompt part in each prompt-response pair.
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}
         max_length = data_args.max_source_length + data_args.max_target_length

-        for query, response, history, prefix in construct_example(examples):
+        for query, response, history, system in construct_example(examples):
             input_ids, labels = [], []

-            for i, (query_i, resp_i) in enumerate(template.get_dialog(query, response, history, prefix)):
-                source_ids = tokenizer.encode(text=query_i, add_special_tokens=(i == 0))
-                target_ids = tokenizer.encode(text=resp_i, add_special_tokens=False)
-
+            for source_ids, target_ids in template.encode_multiturn(tokenizer, query, response, history, system):
                 if len(source_ids) > data_args.max_source_length:
                     source_ids = source_ids[:data_args.max_source_length]
-                if len(target_ids) > data_args.max_target_length - 1: # eos token
-                    target_ids = target_ids[:data_args.max_target_length - 1]
+                if len(target_ids) > data_args.max_target_length:
+                    target_ids = target_ids[:data_args.max_target_length]

-                if len(input_ids) + len(source_ids) + len(target_ids) + 1 > max_length:
+                if len(input_ids) + len(source_ids) + len(target_ids) > max_length:
                     break

-                input_ids += source_ids + target_ids + [tokenizer.eos_token_id]
-                labels += [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
+                input_ids += source_ids + target_ids
+                labels += [IGNORE_INDEX] * len(source_ids) + target_ids

             model_inputs["input_ids"].append(input_ids)
             model_inputs["attention_mask"].append([1] * len(input_ids))
@@ -77,14 +78,11 @@ def preprocess_dataset(
         return model_inputs

     def preprocess_unsupervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
-        # build inputs with format `<bos> X` and labels with format `<bos> Y`
+        # build inputs with format `<bos> X` and labels with format `Y <eos>`
         model_inputs = {"input_ids": [], "attention_mask": [], "labels": []}

-        for query, response, history, prefix in construct_example(examples):
-            prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
-            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
-            target_ids = tokenizer.encode(text=response, add_special_tokens=True)
+        for query, response, history, system in construct_example(examples):
+            source_ids, target_ids = template.encode_oneturn(tokenizer, query, response, history, system)

             if len(source_ids) > data_args.max_source_length:
                 source_ids = source_ids[:data_args.max_source_length]
@@ -98,43 +96,39 @@ def preprocess_dataset(
         return model_inputs

     def preprocess_pairwise_dataset(examples):
-        # build input pairs with format `<bos> X Y1 <eos>` and `<bos> X Y2 <eos>`
-        model_inputs = {"accept_ids": [], "reject_ids": []}
-        for query, response, history, prefix in construct_example(examples):
-            prompt = template.get_prompt(query, history, prefix, tokenizer.eos_token)
-
-            source_ids = tokenizer.encode(text=prompt, add_special_tokens=True)
-            accept_ids = tokenizer.encode(text=response[0], add_special_tokens=False)
-            reject_ids = tokenizer.encode(text=response[1], add_special_tokens=False)
-
-            if len(source_ids) > data_args.max_source_length:
-                source_ids = source_ids[:data_args.max_source_length]
-            if len(accept_ids) > data_args.max_target_length - 1: # eos token
-                accept_ids = accept_ids[:data_args.max_target_length - 1]
-            if len(reject_ids) > data_args.max_target_length - 1: # eos token
-                reject_ids = reject_ids[:data_args.max_target_length - 1]
-
-            accept_ids = source_ids + accept_ids + [tokenizer.eos_token_id]
-            reject_ids = source_ids + reject_ids + [tokenizer.eos_token_id]
-
-            model_inputs["accept_ids"].append(accept_ids)
-            model_inputs["reject_ids"].append(reject_ids)
+        # build input pairs with format `<bos> X`, `Y1 <eos>` and `Y2 <eos>`
+        model_inputs = {"prompt_ids": [], "chosen_ids": [], "rejected_ids": []}
+        for query, response, history, system in construct_example(examples):
+            prompt_ids, chosen_ids = template.encode_oneturn(tokenizer, query, response[0], history, system)
+            _, rejected_ids = template.encode_oneturn(tokenizer, query, response[1], history, system)
+
+            if len(prompt_ids) > data_args.max_source_length:
+                prompt_ids = prompt_ids[:data_args.max_source_length]
+            if len(chosen_ids) > data_args.max_target_length:
+                chosen_ids = chosen_ids[:data_args.max_target_length]
+            if len(rejected_ids) > data_args.max_target_length:
+                rejected_ids = rejected_ids[:data_args.max_target_length]
+
+            model_inputs["prompt_ids"].append(prompt_ids)
+            model_inputs["chosen_ids"].append(chosen_ids)
+            model_inputs["rejected_ids"].append(rejected_ids)
         return model_inputs

     def print_supervised_dataset_example(example):
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
         print("label_ids:\n{}".format(example["labels"]))
-        print("labels:\n{}".format(
-            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
-            skip_special_tokens=False)
-        ))
+        print("labels:\n{}".format(tokenizer.decode([
+            token_id if token_id != IGNORE_INDEX else tokenizer.pad_token_id for token_id in example["labels"]
+        ], skip_special_tokens=False)))

     def print_pairwise_dataset_example(example):
-        print("accept_ids:\n{}".format(example["accept_ids"]))
-        print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
-        print("reject_ids:\n{}".format(example["reject_ids"]))
-        print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
+        print("prompt_ids:\n{}".format(example["prompt_ids"]))
+        print("prompt:\n{}".format(tokenizer.decode(example["prompt_ids"], skip_special_tokens=False)))
+        print("chosen_ids:\n{}".format(example["chosen_ids"]))
+        print("chosen:\n{}".format(tokenizer.decode(example["chosen_ids"], skip_special_tokens=False)))
+        print("rejected_ids:\n{}".format(example["rejected_ids"]))
+        print("rejected:\n{}".format(tokenizer.decode(example["rejected_ids"], skip_special_tokens=False)))

     def print_unsupervised_dataset_example(example):
         print("input_ids:\n{}".format(example["input_ids"]))
@@ -173,8 +167,5 @@ def preprocess_dataset(
         **kwargs
     )

-    if data_args.streaming:
-        dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
-
     print_function(next(iter(dataset)))
     return dataset
```
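
The supervised branch now appends raw `source_ids + target_ids` per turn (the eos token is emitted by the template itself) and masks only the prompt part. A toy sketch of the masking invariant, using made-up token ids:

```python
# Sketch: the prompt-masking scheme used by preprocess_supervised_dataset.
IGNORE_INDEX = -100  # same constant as llmtuner.extras.constants

source_ids = [1, 500, 501, 502]   # encoded prompt (already truncated)
target_ids = [600, 601, 2]        # encoded response, eos included by the template

input_ids = source_ids + target_ids
labels = [IGNORE_INDEX] * len(source_ids) + target_ids

# Only response tokens contribute to the loss; lengths always match.
assert len(input_ids) == len(labels)
```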

**src/llmtuner/dsets/utils.py**

```diff
@@ -1,15 +1,59 @@
-from typing import TYPE_CHECKING, Dict
+import hashlib
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
+
+from llmtuner.extras.logging import get_logger

 if TYPE_CHECKING:
-    from datasets import Dataset
+    from datasets import Dataset, IterableDataset
+    from transformers import TrainingArguments
+    from llmtuner.hparams import DataArguments


-def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]:
-    if do_train:
-        if dev_ratio > 1e-6: # Split the dataset
-            dataset = dataset.train_test_split(test_size=dev_ratio)
-            return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
+logger = get_logger(__name__)
+
+
+EXT2TYPE = {
+    "csv": "csv",
+    "json": "json",
+    "jsonl": "json",
+    "txt": "text"
+}
+
+
+def checksum(data_files: List[str], file_sha1: Optional[str] = None) -> None:
+    if file_sha1 is None:
+        logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
+        return
+
+    if len(data_files) != 1:
+        logger.warning("Checksum failed: too many files.")
+        return
+
+    with open(data_files[0], "rb") as f:
+        sha1 = hashlib.sha1(f.read()).hexdigest()
+        if sha1 != file_sha1:
+            logger.warning("Checksum failed: mismatched SHA-1 hash value at {}.".format(data_files[0]))
+
+
+def split_dataset(
+    dataset: Union["Dataset", "IterableDataset"],
+    data_args: "DataArguments",
+    training_args: "TrainingArguments"
+) -> Dict[str, "Dataset"]:
+    if training_args.do_train:
+        if data_args.val_size > 1e-6: # Split the dataset
+            if data_args.streaming:
+                val_set = dataset.take(int(data_args.val_size))
+                train_set = dataset.skip(int(data_args.val_size))
+                dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+                return {"train_dataset": train_set, "eval_dataset": val_set}
+            else:
+                val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
+                dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
+                return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
         else:
+            if data_args.streaming:
+                dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
             return {"train_dataset": dataset}
     else: # do_eval or do_predict
         return {"eval_dataset": dataset}
```
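
With `checksum` now living in `dsets/utils.py`, dataset files can be verified against the SHA-1 recorded in `dataset_info.json`. A hypothetical call (the path and hash are placeholders; a mismatch only logs a warning rather than raising):

```python
# Sketch: verifying a dataset file against its recorded hash.
from llmtuner.dsets.utils import checksum

checksum(
    data_files=["data/my_dataset.json"],                       # placeholder file
    file_sha1="0000000000000000000000000000000000000000"       # placeholder hash
)
```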

**src/llmtuner/extras/callbacks.py**

```diff
@@ -5,67 +5,124 @@ from typing import TYPE_CHECKING
 from datetime import timedelta

 from transformers import TrainerCallback
+from transformers.trainer_utils import has_length
+
+from llmtuner.extras.constants import LOG_FILE_NAME
+from llmtuner.extras.logging import get_logger

 if TYPE_CHECKING:
     from transformers import TrainingArguments, TrainerState, TrainerControl


+logger = get_logger(__name__)
+
+
 class LogCallback(TrainerCallback):

     def __init__(self, runner=None):
         self.runner = runner
+        self.in_training = False
         self.start_time = time.time()
-        self.tracker = {}
+        self.cur_steps = 0
+        self.max_steps = 0
+        self.elapsed_time = ""
+        self.remaining_time = ""
+
+    def timing(self):
+        cur_time = time.time()
+        elapsed_time = cur_time - self.start_time
+        avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
+        remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
+        self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
+        self.remaining_time = str(timedelta(seconds=int(remaining_time)))

     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the beginning of training.
         """
-        self.start_time = time.time()
+        if state.is_local_process_zero:
+            self.in_training = True
+            self.start_time = time.time()
+            self.max_steps = state.max_steps
+            if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)):
+                logger.warning("Previous log file in this folder will be deleted.")
+                os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))

-    def on_step_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
-        Event called at the beginning of a training step. If using gradient accumulation, one training step
-        might take several inputs.
+        Event called at the end of training.
         """
-        if self.runner is not None and self.runner.aborted:
-            control.should_epoch_stop = True
-            control.should_training_stop = True
+        if state.is_local_process_zero:
+            self.in_training = False
+            self.cur_steps = 0
+            self.max_steps = 0

     def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of an substep during gradient accumulation.
         """
-        if self.runner is not None and self.runner.aborted:
+        if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
             control.should_epoch_stop = True
             control.should_training_stop = True

+    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        r"""
+        Event called at the end of a training step.
+        """
+        if state.is_local_process_zero:
+            self.cur_steps = state.global_step
+            self.timing()
+            if self.runner is not None and self.runner.aborted:
+                control.should_epoch_stop = True
+                control.should_training_stop = True
+
+    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        r"""
+        Event called after an evaluation phase.
+        """
+        if state.is_local_process_zero and not self.in_training:
+            self.cur_steps = 0
+            self.max_steps = 0
+
+    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
+        r"""
+        Event called after a successful prediction.
+        """
+        if state.is_local_process_zero and not self.in_training:
+            self.cur_steps = 0
+            self.max_steps = 0
+
     def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
         r"""
         Event called after logging the last logs.
         """
-        if not state.is_world_process_zero:
+        if not state.is_local_process_zero:
             return

-        cur_time = time.time()
-        cur_steps = state.log_history[-1].get("step")
-        elapsed_time = cur_time - self.start_time
-        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
-        remaining_steps = state.max_steps - cur_steps
-        remaining_time = remaining_steps * avg_time_per_step
-        self.tracker = {
-            "current_steps": cur_steps,
-            "total_steps": state.max_steps,
-            "loss": state.log_history[-1].get("loss", None),
-            "eval_loss": state.log_history[-1].get("eval_loss", None),
-            "predict_loss": state.log_history[-1].get("predict_loss", None),
-            "reward": state.log_history[-1].get("reward", None),
-            "learning_rate": state.log_history[-1].get("learning_rate", None),
-            "epoch": state.log_history[-1].get("epoch", None),
-            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
-            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
-            "remaining_time": str(timedelta(seconds=int(remaining_time)))
-        }
+        logs = dict(
+            current_steps=self.cur_steps,
+            total_steps=self.max_steps,
+            loss=state.log_history[-1].get("loss", None),
+            eval_loss=state.log_history[-1].get("eval_loss", None),
+            predict_loss=state.log_history[-1].get("predict_loss", None),
+            reward=state.log_history[-1].get("reward", None),
+            learning_rate=state.log_history[-1].get("learning_rate", None),
+            epoch=state.log_history[-1].get("epoch", None),
+            percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
+            elapsed_time=self.elapsed_time,
+            remaining_time=self.remaining_time
+        )
         os.makedirs(args.output_dir, exist_ok=True)
         with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
-            f.write(json.dumps(self.tracker) + "\n")
+            f.write(json.dumps(logs) + "\n")
+
+    def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
+        r"""
+        Event called after a prediction step.
+        """
+        eval_dataloader = kwargs.pop("eval_dataloader", None)
+        if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
+            if self.max_steps == 0:
+                self.max_steps = len(eval_dataloader)
+            self.cur_steps += 1
+            self.timing()
```
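
Every `on_log` event appends one JSON object per line to `trainer_log.jsonl`, which the web UI polls. A small reader sketch (not part of the library) using only the keys written above:

```python
# Sketch: tail the progress log emitted by LogCallback.
import json

with open("path_to_output_dir/trainer_log.jsonl", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        print("{percentage}% | step {current_steps}/{total_steps} | "
              "ETA {remaining_time}".format(**record))
```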

**src/llmtuner/extras/constants.py**

```diff
@@ -1,13 +1,23 @@
 IGNORE_INDEX = -100

+LOG_FILE_NAME = "trainer_log.jsonl"
+
 VALUE_HEAD_FILE_NAME = "value_head.bin"

 FINETUNING_ARGS_NAME = "finetuning_args.json"

-LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"] # for LLaMA, BLOOM and Falcon settings
+LAYERNORM_NAMES = ["norm", "ln_f", "ln_attn", "ln_mlp"]

 METHODS = ["full", "freeze", "lora"]

+STAGES = [
+    "SFT",
+    "Reward Modeling",
+    "PPO",
+    "DPO",
+    "Pre-Training"
+]
+
 SUPPORTED_MODELS = {
     "LLaMA-7B": "huggyllama/llama-7b",
     "LLaMA-13B": "huggyllama/llama-13b",
@@ -19,29 +29,50 @@ SUPPORTED_MODELS = {
     "LLaMA2-7B-Chat": "meta-llama/Llama-2-7b-chat-hf",
     "LLaMA2-13B-Chat": "meta-llama/Llama-2-13b-chat-hf",
     "LLaMA2-70B-Chat": "meta-llama/Llama-2-70b-chat-hf",
+    "ChineseLLaMA2-7B": "ziqingyang/chinese-llama-2-7b",
+    "ChineseLLaMA2-13B": "ziqingyang/chinese-llama-2-13b",
+    "ChineseLLaMA2-7B-Chat": "ziqingyang/chinese-alpaca-2-7b",
+    "ChineseLLaMA2-13B-Chat": "ziqingyang/chinese-alpaca-2-13b",
     "BLOOM-560M": "bigscience/bloom-560m",
     "BLOOM-3B": "bigscience/bloom-3b",
     "BLOOM-7B1": "bigscience/bloom-7b1",
     "BLOOMZ-560M": "bigscience/bloomz-560m",
     "BLOOMZ-3B": "bigscience/bloomz-3b",
     "BLOOMZ-7B1-mt": "bigscience/bloomz-7b1-mt",
-    "Falcon-7B-Base": "tiiuae/falcon-7b",
+    "Falcon-7B": "tiiuae/falcon-7b",
     "Falcon-7B-Chat": "tiiuae/falcon-7b-instruct",
-    "Falcon-40B-Base": "tiiuae/falcon-40b",
+    "Falcon-40B": "tiiuae/falcon-40b",
     "Falcon-40B-Chat": "tiiuae/falcon-40b-instruct",
     "Baichuan-7B": "baichuan-inc/Baichuan-7B",
-    "Baichuan-13B-Base": "baichuan-inc/Baichuan-13B-Base",
+    "Baichuan-13B": "baichuan-inc/Baichuan-13B-Base",
     "Baichuan-13B-Chat": "baichuan-inc/Baichuan-13B-Chat",
-    "InternLM-7B-Base": "internlm/internlm-7b",
-    "InternLM-7B-Chat": "internlm/internlm-chat-7b"
+    "InternLM-7B": "internlm/internlm-7b",
+    "InternLM-7B-Chat": "internlm/internlm-chat-7b",
+    "Qwen-7B": "Qwen/Qwen-7B",
+    "Qwen-7B-Chat": "Qwen/Qwen-7B-Chat",
+    "XVERSE-13B": "xverse/XVERSE-13B",
+    "ChatGLM2-6B-Chat": "THUDM/chatglm2-6b"
 }

 DEFAULT_MODULE = {
     "LLaMA": "q_proj,v_proj",
     "LLaMA2": "q_proj,v_proj",
+    "ChineseLLaMA2": "q_proj,v_proj",
     "BLOOM": "query_key_value",
     "BLOOMZ": "query_key_value",
     "Falcon": "query_key_value",
     "Baichuan": "W_pack",
-    "InternLM": "q_proj,v_proj"
+    "InternLM": "q_proj,v_proj",
+    "Qwen": "c_attn",
+    "XVERSE": "q_proj,v_proj",
+    "ChatGLM2": "query_key_value"
+}
+
+DEFAULT_TEMPLATE = {
+    "LLaMA2": "llama2",
+    "ChineseLLaMA2": "llama2_zh",
+    "Baichuan": "baichuan",
+    "InternLM": "intern",
+    "Qwen": "chatml",
+    "ChatGLM2": "chatglm2"
 }
```

**src/llmtuner/extras/logging.py**

```diff
@@ -8,6 +8,9 @@ class LoggerHandler(logging.Handler):
         super().__init__()
         self.log = ""

+    def reset(self):
+        self.log = ""
+
     def emit(self, record):
         if record.name == "httpx":
             return
```

**src/llmtuner/extras/misc.py**

```diff
@@ -1,8 +1,6 @@
 import torch
 from typing import TYPE_CHECKING, List, Optional, Tuple
-
-from transformers.generation.utils import LogitsProcessorList
-from transformers.generation.logits_process import LogitsProcessor
+from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList

 from llmtuner.extras.constants import LAYERNORM_NAMES
@@ -30,19 +28,9 @@ class AverageMeter:
         self.avg = self.sum / self.count


-# Avoids runtime error in model.generate(do_sample=True).
-class InvalidScoreLogitsProcessor(LogitsProcessor):
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
-        if torch.isnan(scores).any() or torch.isinf(scores).any():
-            scores.zero_()
-            scores[..., 0] = 1.0
-        return scores
-
-
 def get_logits_processor() -> LogitsProcessorList:
     logits_processor = LogitsProcessorList()
-    logits_processor.append(InvalidScoreLogitsProcessor())
+    logits_processor.append(InfNanRemoveLogitsProcessor())
     return logits_processor
@@ -77,7 +65,6 @@ def prepare_model_for_training(
     use_gradient_checkpointing: Optional[bool] = True,
     layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
 ) -> "PreTrainedModel":
-
     for name, param in model.named_parameters():
         if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
             param.data = param.data.to(torch.float32)
@@ -94,9 +81,6 @@ def prepare_model_for_training(
         model.config.use_cache = False # turn off when gradient checkpointing is enabled

     if finetuning_type != "full" and hasattr(model, output_layer_name):
-        if hasattr(model, "config") and hasattr(model.config, "pretraining_tp"):
-            model.config.pretraining_tp = 1 # disable TP for LoRA (https://github.com/huggingface/peft/pull/728)
-
         output_layer: torch.nn.Linear = getattr(model, output_layer_name)
         input_dtype = output_layer.weight.dtype
@@ -124,6 +108,9 @@ def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel":
     Dispatches a pre-trained model to GPUs with balanced memory.
     Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803
     """
+    if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing
+        return model
+
     if torch.cuda.device_count() > 1:
         from accelerate import dispatch_model
         from accelerate.utils import infer_auto_device_map, get_balanced_memory
```
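
Swapping the custom `InvalidScoreLogitsProcessor` for the stock `InfNanRemoveLogitsProcessor` keeps `generate(do_sample=True)` from crashing on nan/inf scores. A standalone demonstration:

```python
# Sketch: InfNanRemoveLogitsProcessor replaces nan/inf logits with finite values.
import torch
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList

processor = LogitsProcessorList([InfNanRemoveLogitsProcessor()])
input_ids = torch.tensor([[0]])
scores = torch.tensor([[0.5, float("nan"), float("inf")]])
print(processor(input_ids, scores))  # nan -> 0.0, inf -> dtype max
```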

**src/llmtuner/extras/template.py**

```diff
@@ -1,92 +1,230 @@
-from typing import Dict, List, Optional, Tuple
+import tiktoken
 from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+from llmtuner.extras.logging import get_logger
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
+
+
+logger = get_logger(__name__)


 @dataclass
 class Template:

-    prefix: str
-    prompt: str
-    sep: str
+    prefix: List[Union[str, Dict[str, str]]]
+    prompt: List[Union[str, Dict[str, str]]]
+    system: str
+    sep: List[Union[str, Dict[str, str]]]
+    stop_words: List[str]
     use_history: bool

-    def get_prompt(
+    def encode_oneturn(
         self,
+        tokenizer: "PreTrainedTokenizer",
         query: str,
+        resp: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = "",
-        eos_token: Optional[str] = "</s>"
-    ) -> str:
+        system: Optional[str] = None
+    ) -> Tuple[List[int], List[int]]:
         r"""
-        Returns a string containing prompt without response.
+        Returns a single pair of token ids representing prompt and response respectively.
         """
-        return eos_token.join(map(lambda x: x[0] + x[1], self._format_example(query, history, prefix)))
+        system, history = self._format(query, resp, history, system)
+        encoded_pairs = self._encode(tokenizer, system, history)
+        prompt_ids = []
+        for query_ids, resp_ids in encoded_pairs[:-1]:
+            prompt_ids = prompt_ids + query_ids + resp_ids
+        prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1]
+        return prompt_ids, answer_ids

-    def get_dialog(
+    def encode_multiturn(
         self,
+        tokenizer: "PreTrainedTokenizer",
         query: str,
         resp: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = ""
-    ) -> List[Tuple[str, str]]:
+        system: Optional[str] = None
+    ) -> List[Tuple[List[int], List[int]]]:
         r"""
-        Returns a list containing prompt-response pairs.
+        Returns multiple pairs of token ids representing prompts and responses respectively.
         """
-        result = self._format_example(query, history, prefix)
-        result[-1][-1] = resp
-        return result
+        system, history = self._format(query, resp, history, system)
+        encoded_pairs = self._encode(tokenizer, system, history)
+        return encoded_pairs

-    def _format_example(
+    def _format(
         self,
         query: str,
+        resp: str,
         history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = ""
-    ) -> List[Tuple[str, str]]:
-        prefix = prefix or self.prefix # use prefix if provided
-        prefix = prefix + self.sep if prefix else "" # add separator for non-empty prefix
+        system: Optional[str] = None
+    ) -> Tuple[str, List[Tuple[str, str]]]:
+        r"""
+        Aligns inputs to the standard format.
+        """
+        system = system or self.system # use system if provided
         history = history if (history and self.use_history) else []
-        history = history + [(query, "")]
-        return [
-            [(self.sep if i else prefix) + self.prompt.format(query=q), r]
-            for i, (q, r) in enumerate(history)
-        ]
+        history = history + [(query, resp)]
+        return system, history
+
+    def _get_special_ids(
+        self,
+        tokenizer: "PreTrainedTokenizer"
+    ) -> Tuple[List[int], List[int]]:
+        if (
+            tokenizer.bos_token_id is not None
+            and getattr(tokenizer, "add_bos_token", True)
+        ): # baichuan-13b has no bos token
+            bos_ids = [tokenizer.bos_token_id]
+        else:
+            bos_ids = [] # bos token is optional
+
+        if tokenizer.eos_token_id is not None:
+            eos_ids = [tokenizer.eos_token_id]
+        else:
+            raise ValueError("EOS token is required.")
+
+        return bos_ids, eos_ids
+
+    def _encode(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        system: str,
+        history: List[Tuple[str, str]]
+    ) -> List[Tuple[List[int], List[int]]]:
+        r"""
+        Encodes formatted inputs to pairs of token ids.
+        Turn 0: bos + prefix + sep + query    resp + eos
+        Turn t: sep + bos + query             resp + eos
+        """
+        bos_ids, eos_ids = self._get_special_ids(tokenizer)
+        sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
+        encoded_pairs = []
+        for turn_idx, (query, resp) in enumerate(history):
+            if turn_idx == 0:
+                prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system)
+                if len(prefix_ids) != 0: # has prefix
+                    prefix_ids = bos_ids + prefix_ids + sep_ids
+                else:
+                    prefix_ids = bos_ids
+            else:
+                prefix_ids = sep_ids + bos_ids
+            query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx))
+            resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
+            encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids))
+        return encoded_pairs
+
+    def _convert_inputs_to_ids(
+        self,
+        tokenizer: "PreTrainedTokenizer",
+        context: List[Union[str, Dict[str, str]]],
+        system: Optional[str] = None,
+        query: Optional[str] = None,
+        idx: Optional[str] = None
+    ) -> List[int]:
+        r"""
+        Converts context to token ids.
+        """
+        if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding): # for tiktoken tokenizer (Qwen)
+            kwargs = dict(allowed_special="all")
+        else:
+            kwargs = dict(add_special_tokens=False)
+
+        token_ids = []
+        for elem in context:
+            if isinstance(elem, str):
+                elem = elem.replace("{{system}}", system, 1) if system is not None else elem
+                elem = elem.replace("{{query}}", query, 1) if query is not None else elem
+                elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem
+                token_ids = token_ids + tokenizer.encode(elem, **kwargs)
+            elif isinstance(elem, dict):
+                token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))]
+            else:
+                raise NotImplementedError
+        return token_ids


 @dataclass
 class Llama2Template(Template):

-    def _format_example(
+    def _encode(
         self,
-        query: str,
-        history: Optional[List[Tuple[str, str]]] = None,
-        prefix: Optional[str] = ""
-    ) -> List[Tuple[str, str]]:
-        prefix = prefix or self.prefix # use prefix if provided
-        prefix = prefix if prefix.startswith("<<SYS>>") else "<<SYS>>\n{}\n<</SYS>>\n\n".format(prefix)
-        history = history if (history and self.use_history) else []
-        history = history + [(query, "")]
-        return [
-            [(self.sep if i else "") + self.prompt.format(query=(q if i else prefix + q)), r]
-            for i, (q, r) in enumerate(history)
-        ]
+        tokenizer: "PreTrainedTokenizer",
+        system: str,
+        history: List[Tuple[str, str]]
+    ) -> List[Tuple[List[int], List[int]]]:
+        r"""
+        Encodes formatted inputs to pairs of token ids.
+        Turn 0: bos + prefix + query    resp + eos
+        Turn t: bos + query             resp + eos
+        """
+        bos_ids, eos_ids = self._get_special_ids(tokenizer)
+        encoded_pairs = []
+        for turn_idx, (query, resp) in enumerate(history):
+            if turn_idx == 0: # llama2 template has no sep_ids
+                query = self.prefix[0].replace("{{system}}", system) + query
+            query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
+            resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
+            encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
+        return encoded_pairs


 templates: Dict[str, Template] = {}


-def register_template(name: str, prefix: str, prompt: str, sep: str, use_history: bool) -> None:
-    template_class = Llama2Template if name == "llama2" else Template
+def register_template(
+    name: str,
+    prefix: List[Union[str, Dict[str, str]]],
+    prompt: List[Union[str, Dict[str, str]]],
+    system: str,
+    sep: List[Union[str, Dict[str, str]]],
+    stop_words: Optional[List[str]] = [],
+    use_history: Optional[bool] = True
+) -> None:
+    template_class = Llama2Template if "llama2" in name else Template
     templates[name] = template_class(
         prefix=prefix,
         prompt=prompt,
+        system=system,
         sep=sep,
+        stop_words=stop_words,
         use_history=use_history
     )


-def get_template(name: str) -> Template:
+def get_template_and_fix_tokenizer(
+    name: str,
+    tokenizer: "PreTrainedTokenizer"
+) -> Template:
     template = templates.get(name, None)
     assert template is not None, "Template {} does not exist.".format(name)
+
+    additional_special_tokens = template.stop_words
+    if len(template.stop_words): # inplace method
+        if tokenizer.eos_token_id is not None:
+            additional_special_tokens.append(tokenizer.eos_token)
+        tokenizer.eos_token = additional_special_tokens[0] # use the first stop word as eos token
+        additional_special_tokens.pop(0)
+        logger.info("Replace eos token: {}".format(tokenizer.eos_token))
+
+    if tokenizer.eos_token_id is None:
+        tokenizer.eos_token = "<|endoftext|>"
+        logger.info("Add eos token: {}".format(tokenizer.eos_token))
+
+    if tokenizer.pad_token_id is None:
+        if tokenizer.unk_token_id is not None:
+            tokenizer.pad_token = tokenizer.unk_token
+        else:
+            tokenizer.pad_token = tokenizer.eos_token
+        logger.info("Add pad token: {}".format(tokenizer.pad_token))
+
+    tokenizer.add_special_tokens(dict(additional_special_tokens=additional_special_tokens))
     return template
```
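
The reworked template API encodes straight to token ids and repairs the tokenizer's special tokens on the way. A usage sketch mirroring how `stream_chat.py` above calls it (the model path is a placeholder):

```python
# Sketch: encode one turn with the new template API.
from transformers import AutoTokenizer
from llmtuner.extras.template import get_template_and_fix_tokenizer

tokenizer = AutoTokenizer.from_pretrained("path_to_llama_model")
template = get_template_and_fix_tokenizer("default", tokenizer)

prompt_ids, answer_ids = template.encode_oneturn(
    tokenizer=tokenizer, query="Hello!", resp="Hi!", history=None, system=None
)
print(tokenizer.decode(prompt_ids, skip_special_tokens=False))
```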
The remainder of the template.py diff rewrites each registered template into the new list-based format:

```diff
@@ -95,9 +233,12 @@ Supports language model inference without histories.
 """
 register_template(
     name="vanilla",
-    prefix="",
-    prompt="{query}",
-    sep="",
+    prefix=[],
+    prompt=[
+        "{{query}}"
+    ],
+    system="",
+    sep=[],
     use_history=False
 )
@@ -107,11 +248,19 @@ Default template.
 """
 register_template(
     name="default",
-    prefix="A chat between a curious user and an artificial intelligence assistant. "
-           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
-    prompt="Human: {query}\nAssistant: ",
-    sep="\n",
-    use_history=True
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "Human: {{query}}\nAssistant: "
+    ],
+    system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    sep=[
+        "\n"
+    ]
 )
@@ -122,17 +271,40 @@ Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
 """
 register_template(
     name="llama2",
-    prefix="<<SYS>>\nYou are a helpful, respectful and honest assistant. "
-           "Always answer as helpfully as possible, while being safe. "
-           "Your answers should not include any harmful, unethical, "
-           "racist, sexist, toxic, dangerous, or illegal content. "
-           "Please ensure that your responses are socially unbiased and positive in nature.\n"
-           "If a question does not make any sense, or is not factually coherent, "
-           "explain why instead of answering something not correct. "
-           "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n",
-    prompt="[INST] {query} [/INST] ",
-    sep="<s>",
-    use_history=True
+    prefix=[
+        "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
+    ],
+    prompt=[
+        "[INST] {{query}} [/INST] "
+    ],
+    system=(
+        "You are a helpful, respectful and honest assistant. "
+        "Always answer as helpfully as possible, while being safe. "
+        "Your answers should not include any harmful, unethical, "
+        "racist, sexist, toxic, dangerous, or illegal content. "
+        "Please ensure that your responses are socially unbiased and positive in nature.\n"
+        "If a question does not make any sense, or is not factually coherent, "
+        "explain why instead of answering something not correct. "
+        "If you don't know the answer to a question, please don't share false information."
+    ),
+    sep=[]
+)
+
+
+r"""
+Supports: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2
+          https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
+"""
+register_template(
+    name="llama2_zh",
+    prefix=[
+        "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
+    ],
+    prompt=[
+        "[INST] {{query}} [/INST] "
+    ],
+    system="You are a helpful assistant. 你是一个乐于助人的助手。",
+    sep=[]
 )
@@ -142,11 +314,19 @@ Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
 """
 register_template(
     name="alpaca",
-    prefix="Below is an instruction that describes a task. "
-           "Write a response that appropriately completes the request.",
-    prompt="### Instruction:\n{query}\n\n### Response:\n",
-    sep="\n\n",
-    use_history=True
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "### Instruction:\n{{query}}\n\n### Response:\n"
+    ],
+    system=(
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request."
+    ),
+    sep=[
+        "\n\n"
+    ]
 )
@@ -156,11 +336,17 @@ Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
 """
 register_template(
     name="vicuna",
-    prefix="A chat between a curious user and an artificial intelligence assistant. "
-           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
-    prompt="USER: {query} ASSISTANT: ",
-    sep="",
-    use_history=True
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "USER: {{query}} ASSISTANT: "
+    ],
+    system=(
+        "A chat between a curious user and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the user's questions."
+    ),
+    sep=[]
 )
@@ -169,10 +355,16 @@ Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
 """
 register_template(
     name="belle",
-    prefix="",
-    prompt="Human: {query}\n\nBelle: ",
-    sep="\n\n",
-    use_history=True
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "Human: {{query}}\n\nBelle: "
+    ],
+    system="",
+    sep=[
+        "\n\n"
+    ]
 )
@@ -181,10 +373,16 @@ Supports: https://github.com/CVI-SZU/Linly
 """
 register_template(
     name="linly",
-    prefix="",
-    prompt="User: {query}\nBot: ",
-    sep="\n",
-    use_history=True
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "User: {{query}}\nBot: "
+    ],
+    system="",
+    sep=[
+        "\n"
+    ]
 )
@@ -193,10 +391,16 @@ Supports: https://github.com/Neutralzz/BiLLa
 """
 register_template(
     name="billa",
-    prefix="",
-    prompt="Human: {query}\nAssistant: ",
-    sep="\n",
-    use_history=True
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "Human: {{query}}\nAssistant: "
+    ],
+    system="",
+    sep=[
+        "\n"
+    ]
 )
@@ -205,10 +409,19 @@ Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
 """
 register_template(
     name="ziya",
-    prefix="",
-    prompt="<human>:{query}\n<bot>:",
-    sep="\n",
-    use_history=True
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        {"token": "<human>"},
+        ":{{query}}\n",
+        {"token": "<bot>"},
+        ":"
+    ],
+    system="",
+    sep=[
+        "\n"
+    ]
 )
@@ -217,11 +430,19 @@ Supports: https://huggingface.co/qhduan/aquilachat-7b
 """
 register_template(
     name="aquila",
-    prefix="A chat between a curious human and an artificial intelligence assistant. "
-           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
-    prompt="Human: {query}###Assistant: ",
-    sep="###",
-    use_history=True
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "Human: {{query}}###Assistant: "
+    ],
+    system=(
+        "A chat between a curious human and an artificial intelligence assistant. "
+        "The assistant gives helpful, detailed, and polite answers to the human's questions."
+    ),
+    sep=[
+        "###"
+    ]
 )
@@ -230,10 +451,22 @@ Supports: https://huggingface.co/internlm/internlm-chat-7b
 """
 register_template(
     name="intern",
-    prefix="",
-    prompt="<|User|>:{query}<eoh>\n<|Bot|>:",
-    sep="<eoa>\n",
-    use_history=True
+    prefix=[
+        "{{system}}"
+    ],
+    prompt=[
+        "<|User|>:{{query}}",
+        {"token": "<eoh>"},
+        "\n<|Bot|>:"
+    ],
+    system="",
+    sep=[
+        "\n"
+    ],
+    stop_words=[
+        "</s>", # internlm cannot replace eos token
+        "<eoa>"
+    ]
 )
@@ -242,10 +475,19 @@ Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
 """
 register_template(
     name="baichuan",
-    prefix="",
-    prompt="<reserved_102>{query}<reserved_103>",
-    sep="",
-    use_history=True
+    prefix=[
+        "{{system}}",
+        {"token": "<reserved_102>"} # user token
+    ],
+    prompt=[
+        "{{query}}",
+        {"token": "<reserved_103>"} # assistant token
+    ],
+    system="",
+    sep=[],
+    stop_words=[
+        "<reserved_102>" # user token
+    ]
 )
@@ -255,8 +497,71 @@ Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
 """
 register_template(
     name="starchat",
-    prefix="<|system|>\n",
-    prompt="<|user|>\n{query}<|end|>\n<|assistant|>\n",
-    sep="<|end|>\n",
-    use_history=True
+    prefix=[
+        {"token": "<|system|>"},
+        "\n{{system}}",
+        {"token": "<|end|>"}
+    ],
+    prompt=[
+        {"token": "<|user|>"},
+        "\n{{query}}",
+        {"token": "<|end|>"},
+        "\n",
+        {"token": "<|assistant|>"}
```
],
system="",
sep=[
"\n"
],
stop_words=[
"<|end|>"
]
)
r"""
Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
"""
register_template(
name="chatml",
prefix=[
{"token": "<|im_start|>"},
"system\n{{system}}",
{"token": "<|im_end|>"}
],
prompt=[
{"token": "<|im_start|>"},
"user\n{{query}}",
{"token": "<|im_end|>"},
"\n",
{"token": "<|im_start|>"},
"assistant\n"
],
system="You are a helpful assistant.",
sep=[
"\n"
],
stop_words=[
"<|im_end|>"
]
)
r"""
Supports: https://huggingface.co/THUDM/chatglm2-6b
"""
register_template(
name="chatglm2",
prefix=[
{"token": "[gMASK]"},
{"token": "sop"},
"{{system}}"
],
prompt=[
"[Round {{idx}}]\n\n问:{{query}}\n\n答:"
],
system="",
sep=[
"\n\n"
]
) )
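
Taken together, these template changes move from flat prompt strings to lists that mix literal text (with {{system}} and {{query}} placeholders) and {"token": ...} items standing for tokenizer special tokens. A minimal sketch of how such a list could be rendered, assuming plain string substitution (the project's actual encoder works on token ids, so this only illustrates the string shape):

def render(template: dict, system: str, query: str) -> str:
    # Walk prefix + sep + prompt; dicts stand for special tokens,
    # strings get their placeholders substituted.
    pieces = []
    for elem in template["prefix"] + template["sep"] + template["prompt"]:
        if isinstance(elem, dict):
            pieces.append(elem["token"])
        else:
            pieces.append(elem.replace("{{system}}", system).replace("{{query}}", query))
    return "".join(pieces)

chatml = {
    "prefix": [{"token": "<|im_start|>"}, "system\n{{system}}", {"token": "<|im_end|>"}],
    "sep": ["\n"],
    "prompt": [{"token": "<|im_start|>"}, "user\n{{query}}", {"token": "<|im_end|>"}, "\n",
               {"token": "<|im_start|>"}, "assistant\n"],
}
print(render(chatml, "You are a helpful assistant.", "Hello"))
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# Hello<|im_end|>
# <|im_start|>assistant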

View File

@@ -10,7 +10,7 @@ class DatasetAttr:
     load_from: str
     dataset_name: Optional[str] = None
     dataset_sha1: Optional[str] = None
-    source_prefix: Optional[str] = None
+    system_prompt: Optional[str] = None
 
     def __repr__(self) -> str:
         return self.dataset_name
@@ -24,7 +24,7 @@ class DatasetAttr:
 @dataclass
 class DataArguments:
-    """
+    r"""
     Arguments pertaining to what data we are going to input our model for training and evaluation.
     """
     template: str = field(
@@ -54,6 +54,10 @@ class DataArguments:
         default="concat",
         metadata={"help": "Strategy to use in dataset mixing."}
     )
+    interleave_probs: Optional[str] = field(
+        default=None,
+        metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."}
+    )
     overwrite_cache: Optional[bool] = field(
         default=False,
         metadata={"help": "Overwrite the cached training and evaluation sets."}
@@ -82,13 +86,13 @@ class DataArguments:
         default=True,
         metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
     )
-    source_prefix: Optional[str] = field(
+    system_prompt: Optional[str] = field(
         default=None,
-        metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
+        metadata={"help": "System prompt to add before the user query. Use `|` to separate multiple prompts in training."}
     )
-    dev_ratio: Optional[float] = field(
+    val_size: Optional[float] = field(
         default=0,
-        metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
+        metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
     )
 
     def init_for_training(self): # support mixing multiple datasets
@@ -96,12 +100,12 @@ class DataArguments:
         with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
             dataset_info = json.load(f)
 
-        if self.source_prefix is not None:
-            prefix_list = self.source_prefix.split("|")
-            prefix_list = prefix_list * len(dataset_names) if len(prefix_list) == 1 else prefix_list
-            assert len(prefix_list) == len(dataset_names), "The number of prefixes should be either identical with datasets or 1."
-        else:
-            prefix_list = [None] * len(dataset_names)
+        prompt_list = self.system_prompt.split("|") if self.system_prompt else [None]
+        prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
+        assert len(prompt_list) == len(dataset_names), "Number of system prompts should be equal to datasets or 1."
+
+        if self.interleave_probs is not None:
+            self.interleave_probs = [float(prob.strip()) for prob in self.interleave_probs.split(",")]
 
         self.dataset_list: List[DatasetAttr] = []
         for i, name in enumerate(dataset_names):
@@ -119,12 +123,11 @@ class DataArguments:
                     dataset_sha1=dataset_info[name].get("file_sha1", None)
                 )
 
-            dataset_attr.source_prefix = prefix_list[i]
-
             if "columns" in dataset_info[name]:
                 dataset_attr.prompt = dataset_info[name]["columns"].get("prompt", None)
                 dataset_attr.query = dataset_info[name]["columns"].get("query", None)
                 dataset_attr.response = dataset_info[name]["columns"].get("response", None)
                 dataset_attr.history = dataset_info[name]["columns"].get("history", None)
 
+            dataset_attr.system_prompt = prompt_list[i]
             self.dataset_list.append(dataset_attr)
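
The renamed system_prompt keeps the `|`-separated convention of the old source_prefix: either one prompt shared by all datasets or exactly one per dataset. A standalone walk-through of the parsing above, with example dataset names:

dataset_names = ["alpaca_en", "alpaca_zh"]  # example values

system_prompt = "You are a helpful assistant.|你是一个乐于助人的助手。"
prompt_list = system_prompt.split("|") if system_prompt else [None]
prompt_list = prompt_list * (len(dataset_names) // len(prompt_list))
assert len(prompt_list) == len(dataset_names)  # one system prompt per dataset

interleave_probs = "0.3, 0.7"  # sampling weights for dataset mixing
probs = [float(prob.strip()) for prob in interleave_probs.split(",")]
print(prompt_list[1], probs)  # 你是一个乐于助人的助手。 [0.3, 0.7]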

View File

@@ -5,32 +5,36 @@ from dataclasses import asdict, dataclass, field
 @dataclass
 class FinetuningArguments:
-    """
+    r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
     """
-    finetuning_type: Optional[Literal["none", "freeze", "lora", "full"]] = field(
+    finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
         default="lora",
         metadata={"help": "Which fine-tuning method to use."}
     )
     num_hidden_layers: Optional[int] = field(
         default=32,
-        metadata={"help": "Number of decoder blocks in the model. \
+        metadata={"help": "Number of decoder blocks in the model for partial-parameter (freeze) fine-tuning. \
                   LLaMA choices: [\"32\", \"40\", \"60\", \"80\"], \
                   LLaMA-2 choices: [\"32\", \"40\", \"80\"], \
                   BLOOM choices: [\"24\", \"30\", \"70\"], \
                   Falcon choices: [\"32\", \"60\"], \
-                  Baichuan choices: [\"32\", \"40\"]"}
+                  Baichuan choices: [\"32\", \"40\"] \
+                  Qwen choices: [\"32\"], \
+                  XVERSE choices: [\"40\"]"}
     )
     num_layer_trainable: Optional[int] = field(
         default=3,
-        metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
+        metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
     )
     name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
         default="mlp",
-        metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
-                  LLaMA & LLaMA-2 choices: [\"mlp\", \"self_attn\"], \
+        metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
+                  LLaMA choices: [\"mlp\", \"self_attn\"], \
                   BLOOM & Falcon choices: [\"mlp\", \"self_attention\"], \
-                  Baichuan choices: [\"mlp\", \"self_attn\"]"}
+                  Baichuan choices: [\"mlp\", \"self_attn\"], \
+                  Qwen choices: [\"mlp\", \"attn\"], \
+                  LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
     )
     lora_rank: Optional[int] = field(
         default=8,
@@ -45,11 +49,25 @@ class FinetuningArguments:
         metadata={"help": "Dropout rate for the LoRA fine-tuning."}
     )
     lora_target: Optional[str] = field(
-        default="q_proj,v_proj",
+        default=None,
         metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
-                  LLaMA & LLaMA-2 choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
+                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
                   BLOOM & Falcon choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
-                  Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]"}
+                  Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
+                  Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
+                  LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
     )
+    resume_lora_training: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
+    )
+    ppo_score_norm: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Use score normalization in PPO Training."}
+    )
+    dpo_beta: Optional[float] = field(
+        default=0.1,
+        metadata={"help": "The beta parameter for the DPO loss."}
+    )
 
     def __post_init__(self):
@@ -63,17 +81,17 @@ class FinetuningArguments:
         self.trainable_layers = ["{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
 
-        assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
+        assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
 
     def save_to_json(self, json_path: str):
-        """Saves the content of this instance in JSON format inside `json_path`."""
+        r"""Saves the content of this instance in JSON format inside `json_path`."""
         json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
         with open(json_path, "w", encoding="utf-8") as f:
             f.write(json_string)
 
     @classmethod
     def load_from_json(cls, json_path: str):
-        """Creates an instance from the content of `json_path`."""
+        r"""Creates an instance from the content of `json_path`."""
         with open(json_path, "r", encoding="utf-8") as f:
             text = f.read()
         return cls(**json.loads(text))
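
Because lora_target now defaults to None rather than "q_proj,v_proj", a LoRA run has to name its target modules explicitly (the parser change later in this diff raises a ValueError otherwise). A minimal sketch, assuming a LLaMA-style model:

from llmtuner.hparams import FinetuningArguments

finetuning_args = FinetuningArguments(
    finetuning_type="lora",
    lora_rank=8,
    lora_target="q_proj,v_proj",  # must now be given explicitly for LoRA training
    dpo_beta=0.1,                 # new knob, only used by the DPO stage below
)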

View File

@@ -4,10 +4,10 @@ from dataclasses import dataclass, field
 @dataclass
 class GeneralArguments:
-    """
+    r"""
     Arguments pertaining to which stage we are going to perform.
     """
-    stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = field(
+    stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
         default="sft",
         metadata={"help": "Which stage will be performed in training."}
     )

View File

@@ -4,7 +4,7 @@ from dataclasses import asdict, dataclass, field
 @dataclass
 class GeneratingArguments:
-    """
+    r"""
     Arguments pertaining to specify the decoding parameters.
     """
     do_sample: Optional[bool] = field(

View File

@@ -5,7 +5,7 @@ from dataclasses import dataclass, field
 @dataclass
 class ModelArguments:
-    """
+    r"""
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
     """
     model_name_or_path: str = field(
@@ -43,9 +43,9 @@ class ModelArguments:
         default=True,
         metadata={"help": "Whether to use double quantization in int4 training or not."}
     )
-    compute_dtype: Optional[torch.dtype] = field(
+    rope_scaling: Optional[Literal["linear", "dynamic"]] = field(
         default=None,
-        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
+        metadata={"help": "Adopt scaled rotary positional embeddings."}
     )
     checkpoint_dir: Optional[str] = field(
         default=None,
@@ -55,18 +55,33 @@ class ModelArguments:
         default=None,
         metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
     )
-    resume_lora_training: Optional[bool] = field(
-        default=True,
-        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
-    )
     plot_loss: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
     )
+    hf_auth_token: Optional[str] = field(
+        default=None,
+        metadata={"help": "Auth token to log in with Hugging Face Hub."}
+    )
+    compute_dtype: Optional[torch.dtype] = field(
+        default=None,
+        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
+    )
+    model_max_length: Optional[int] = field(
+        default=None,
+        metadata={"help": "Used in rope scaling. Do not specify this argument manually."}
+    )
 
     def __post_init__(self):
+        if self.compute_dtype is not None or self.model_max_length is not None:
+            raise ValueError("These arguments cannot be specified.")
+
         if self.checkpoint_dir is not None: # support merging multiple lora weights
             self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
 
         if self.quantization_bit is not None:
             assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
+
+        if self.use_auth_token == True and self.hf_auth_token is not None:
+            from huggingface_hub.hf_api import HfFolder # lazy load
+            HfFolder.save_token(self.hf_auth_token)

View File

@@ -1,5 +1 @@
-from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
-from llmtuner.tuner.pt import run_pt
-from llmtuner.tuner.sft import run_sft
-from llmtuner.tuner.rm import run_rm
-from llmtuner.tuner.ppo import run_ppo
+from llmtuner.tuner.tune import export_model, run_exp
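
The package now funnels everything through a single entry point instead of per-stage imports. A hypothetical invocation (the signature of run_exp is not shown in this diff; the dict form below leans on _parse_args accepting a dict, as the parser changes later in this diff show):

from llmtuner.tuner import run_exp

run_exp(args={
    "stage": "dpo",                                     # new stage added below
    "model_name_or_path": "meta-llama/Llama-2-7b-hf",   # example value
    "do_train": True,
    "finetuning_type": "lora",
    "lora_target": "q_proj,v_proj",
    "output_dir": "output/dpo",                         # example value
})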

View File

@@ -39,7 +39,7 @@ def init_adapter(
     if finetuning_args.finetuning_type == "none" and is_trainable:
         raise ValueError("You cannot use finetuning_type=none while training.")
 
-    if finetuning_args.finetuning_type == "full":
+    if finetuning_args.finetuning_type == "full" and is_trainable:
         logger.info("Fine-tuning method: Full")
         model = model.float()
@@ -65,7 +65,7 @@ def init_adapter(
             assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
                 "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
 
-            if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
+            if (is_trainable and finetuning_args.resume_lora_training) or (not is_mergeable): # continually fine-tuning
                 checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
             else:
                 checkpoints_to_merge = model_args.checkpoint_dir

View File

@@ -1,18 +1,21 @@
 import os
+import math
 import torch
+from types import MethodType
 from typing import TYPE_CHECKING, Literal, Optional, Tuple
 from transformers import (
     AutoConfig,
     AutoModelForCausalLM,
     AutoTokenizer,
-    BitsAndBytesConfig
+    BitsAndBytesConfig,
+    PretrainedConfig,
+    PreTrainedModel,
+    PreTrainedTokenizerBase
 )
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
 from transformers.deepspeed import is_deepspeed_zero3_enabled
-from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
-from transformers.tokenization_utils import PreTrainedTokenizerBase
 from trl import AutoModelForCausalLMWithValueHead
 
 from llmtuner.extras.logging import reset_logging, get_logger
@@ -22,6 +25,7 @@ from llmtuner.hparams import FinetuningArguments
 from llmtuner.tuner.core.adapter import init_adapter
 
 if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer
     from llmtuner.hparams import ModelArguments
@@ -32,7 +36,7 @@ check_min_version("4.29.1")
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
 require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
-require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
+require_version("trl>=0.5.0", "To fix: pip install trl>=0.5.0")
 
 def load_model_and_tokenizer(
@@ -40,7 +44,7 @@ def load_model_and_tokenizer(
     finetuning_args: "FinetuningArguments",
     is_trainable: Optional[bool] = False,
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
-) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
+) -> Tuple[PreTrainedModel, "PreTrainedTokenizer"]:
     r"""
     Loads pretrained model and tokenizer.
@@ -50,9 +54,6 @@ def load_model_and_tokenizer(
         logger.warning("Checkpoint is not found at evaluation, load the original model.")
         finetuning_args = FinetuningArguments(finetuning_type="none")
 
-    assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
-        "RM and PPO training can only be performed with the LoRA method."
-
     config_kwargs = {
         "trust_remote_code": True,
         "cache_dir": model_args.cache_dir,
@@ -66,21 +67,67 @@ def load_model_and_tokenizer(
         padding_side=model_args.padding_side,
         **config_kwargs
     )
+    if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
+        tokenizer.pad_token_id = 0 # set as the <unk> token
 
-    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
-    is_mergeable = True
+    if finetuning_args.finetuning_type == "full" and model_args.checkpoint_dir is not None:
+        model_to_load = model_args.checkpoint_dir[0]
+    else:
+        model_to_load = model_args.model_name_or_path
+
+    config = AutoConfig.from_pretrained(model_to_load, **config_kwargs)
+
+    if hasattr(config, "fp16") and hasattr(config, "bf16"): # fix Qwen config
+        if model_args.compute_dtype == torch.bfloat16:
+            setattr(config, "bf16", True)
+        else:
+            setattr(config, "fp16", True)
+
+    # Set RoPE scaling
+    if model_args.rope_scaling is not None:
+        if hasattr(config, "use_dynamic_ntk"): # for Qwen models
+            if is_trainable:
+                logger.warning("Qwen model does not support RoPE scaling in training.")
+            else:
+                setattr(config, "use_dynamic_ntk", True)
+                setattr(config, "use_logn_attn", True)
+                logger.info("Using dynamic NTK scaling.")
+
+        elif hasattr(config, "rope_scaling"): # for LLaMA models
+            require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")
+
+            if is_trainable:
+                if model_args.rope_scaling == "dynamic":
+                    logger.warning(
+                        "Dynamic NTK may not work well with fine-tuning. "
+                        "See: https://github.com/huggingface/transformers/pull/24653"
+                    )
+
+                current_max_length = getattr(config, "max_position_embeddings", None)
+                if current_max_length and model_args.model_max_length > current_max_length:
+                    scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
+                else:
+                    logger.warning("Input length is smaller than max length. Consider increase input length.")
+                    scaling_factor = 1.0
+            else:
+                scaling_factor = 2.0
+
+            setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
+            logger.info("Using {} scaling strategy and setting scaling factor to {}".format(
+                model_args.rope_scaling, scaling_factor
+            ))
+
+        else:
+            logger.warning("Current model does not support RoPE scaling.")
 
     # Quantization configurations (using bitsandbytes library).
+    is_mergeable = True
     if model_args.quantization_bit is not None:
+        if is_deepspeed_zero3_enabled():
+            raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantization.")
+
         if model_args.quantization_bit == 8:
             require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
             config_kwargs["load_in_8bit"] = True
-            config_kwargs["quantization_config"] = BitsAndBytesConfig(
-                load_in_8bit=True,
-                llm_int8_threshold=6.0
-            )
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
 
         elif model_args.quantization_bit == 4:
             require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
@@ -93,28 +140,26 @@ def load_model_and_tokenizer(
             )
             is_mergeable = False
 
+        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
-    if (
-        model_args.quantization_bit is not None
-        or (os.environ.get('LOCAL_RANK') is not None and not is_deepspeed_zero3_enabled())
-    ):
-        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
-
-    if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
-        model_to_load = model_args.checkpoint_dir[0]
-    else:
-        model_to_load = model_args.model_name_or_path
-
-    # Load and prepare pretrained models (without valuehead).
+    # Load and prepare pre-trained models (without valuehead).
     model = AutoModelForCausalLM.from_pretrained(
         model_to_load,
         config=config,
-        torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
+        torch_dtype=model_args.compute_dtype,
         low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()),
         **config_kwargs
     )
 
+    # Disable custom generate method (for Qwen)
+    if "GenerationMixin" not in str(model.generate.__func__):
+        model.generate = MethodType(PreTrainedModel.generate, model)
+
+    # Fix LM head (for ChatGLM2)
+    if not hasattr(model, "lm_head") and hasattr(model, "transformer"):
+        setattr(model, "lm_head", model.transformer.output_layer)
+
     # Register auto class to save the custom code files.
     if isinstance(config, PretrainedConfig) and "AutoConfig" in getattr(config, "auto_map", {}):
         config.__class__.register_for_auto_class()
@@ -127,10 +172,10 @@ def load_model_and_tokenizer(
     model = prepare_model_for_training(model, finetuning_args.finetuning_type) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable)
 
-    if stage == "rm" or stage == "ppo": # add value head
-        model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+    # Prepare model with valuehead for RLHF
+    if stage == "rm" or stage == "ppo":
+        model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(model)
         reset_logging()
         if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
             logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
             if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
@@ -140,15 +185,15 @@
                 })
 
         if stage == "ppo": # load reward model
-            assert is_trainable, "PPO stage cannot be performed at evaluation."
-            assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
             logger.info("Load reward model from {}".format(model_args.reward_model))
             model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
             assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
 
+    # Prepare model for inference
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
-        model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
+        infer_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 # detect cuda capability
+        model = model.to(infer_dtype) if model_args.quantization_bit is None else model
 
     trainable_params, all_param = count_parameters(model)
     logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(

View File

@@ -5,6 +5,7 @@ import datasets
 import transformers
 from typing import Any, Dict, Optional, Tuple
 from transformers import HfArgumentParser, Seq2SeqTrainingArguments
+from transformers.trainer_utils import get_last_checkpoint
 
 from llmtuner.extras.logging import get_logger
 from llmtuner.hparams import (
@@ -19,7 +20,7 @@ from llmtuner.hparams import (
 logger = get_logger(__name__)
 
-def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None):
+def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
     if args is not None:
         return parser.parse_dict(args)
     elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
@@ -32,26 +33,53 @@ def _parse_args(parser: HfArgumentParser, args: Optional[Dict[str, Any]] = None)
 def parse_train_args(
     args: Optional[Dict[str, Any]] = None
-) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
+) -> Tuple[
+    ModelArguments,
+    DataArguments,
+    Seq2SeqTrainingArguments,
+    FinetuningArguments,
+    GeneratingArguments,
+    GeneralArguments
+]:
     parser = HfArgumentParser((
-        ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments
+        ModelArguments,
+        DataArguments,
+        Seq2SeqTrainingArguments,
+        FinetuningArguments,
+        GeneratingArguments,
+        GeneralArguments
     ))
     return _parse_args(parser, args)
 
 def parse_infer_args(
     args: Optional[Dict[str, Any]] = None
-) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
+) -> Tuple[
+    ModelArguments,
+    DataArguments,
+    FinetuningArguments,
+    GeneratingArguments
+]:
     parser = HfArgumentParser((
-        ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+        ModelArguments,
+        DataArguments,
+        FinetuningArguments,
+        GeneratingArguments
    ))
     return _parse_args(parser, args)
 
 def get_train_args(
     args: Optional[Dict[str, Any]] = None
-) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
-    model_args, data_args, training_args, finetuning_args, general_args = parse_train_args(args)
+) -> Tuple[
+    ModelArguments,
+    DataArguments,
+    Seq2SeqTrainingArguments,
+    FinetuningArguments,
+    GeneratingArguments,
+    GeneralArguments
+]:
+    model_args, data_args, training_args, finetuning_args, generating_args, general_args = parse_train_args(args)
 
     # Setup logging
     if training_args.should_log:
@@ -67,40 +95,61 @@ get_train_args(
     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
     data_args.init_for_training()
 
-    assert general_args.stage == "sft" or (not training_args.predict_with_generate), \
-        "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
-
-    assert not (training_args.do_train and training_args.predict_with_generate), \
-        "`predict_with_generate` cannot be set as True while training."
-
-    assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
-        "Please enable `predict_with_generate` to save model predictions."
-
-    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
-        "Quantization is only compatible with the LoRA method."
-
-    assert not (training_args.max_steps == -1 and data_args.streaming), \
-        "Please specify `max_steps` in streaming mode."
-
-    assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
-        "Streaming mode does not support evaluation currently."
-
-    assert not (general_args.stage == "ppo" and data_args.streaming), \
-        "Streaming mode does not suppport PPO training currently."
+    if general_args.stage != "sft" and training_args.predict_with_generate:
+        raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
+
+    if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
+        raise ValueError("Please enable `predict_with_generate` to save model predictions.")
+
+    if general_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type != "lora":
+        raise ValueError("RM and PPO stages can only be performed with the LoRA method.")
+
+    if general_args.stage in ["rm", "ppo"] and training_args.resume_from_checkpoint is not None:
+        raise ValueError("RM and PPO stages do not support `resume_from_checkpoint`.")
+
+    if general_args.stage in ["ppo", "dpo"] and not training_args.do_train:
+        raise ValueError("PPO and DPO stages can only be performed at training.")
+
+    if general_args.stage == "ppo" and model_args.reward_model is None:
+        raise ValueError("Reward model is necessary for PPO training.")
+
+    if general_args.stage == "ppo" and data_args.streaming:
+        raise ValueError("Streaming mode does not suppport PPO training currently.")
+
+    if training_args.max_steps == -1 and data_args.streaming:
+        raise ValueError("Please specify `max_steps` in streaming mode.")
+
+    if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming:
+        raise ValueError("Streaming mode should have an integer val size.")
+
+    if training_args.do_train and training_args.predict_with_generate:
+        raise ValueError("`predict_with_generate` cannot be set as True while training.")
+
+    if training_args.do_train and finetuning_args.finetuning_type == "lora" and finetuning_args.lora_target is None:
+        raise ValueError("Please specify `lora_target` in LoRA training.")
+
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
 
     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
-            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-        else:
-            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
-                "Quantized model only accepts a single checkpoint."
+            if len(model_args.checkpoint_dir) != 1:
+                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+        elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Quantized model only accepts a single checkpoint.")
 
     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
 
-    if training_args.do_train and (not training_args.fp16):
-        logger.warning("We recommend enable fp16 mixed precision training.")
+    if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
+        logger.warning("We recommend enable mixed precision training.")
+
+    # postprocess data_args
+    if data_args.max_samples is not None and data_args.streaming:
+        logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
+        data_args.max_samples = None
 
+    # postprocess training_args
     if (
         training_args.local_rank != -1
         and training_args.ddp_find_unused_parameters is None
@@ -109,50 +158,66 @@ def get_train_args(
         logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
         training_args.ddp_find_unused_parameters = False
 
-    if data_args.max_samples is not None and data_args.streaming:
-        logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
-        data_args.max_samples = None
-
-    if data_args.dev_ratio > 1e-6 and data_args.streaming:
-        logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
-        data_args.dev_ratio = 0
-
-    training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
-
-    if model_args.quantization_bit is not None:
-        if training_args.fp16:
-            model_args.compute_dtype = torch.float16
-        elif training_args.bf16:
-            model_args.compute_dtype = torch.bfloat16
-        else:
-            model_args.compute_dtype = torch.float32
+    if training_args.optim == "adamw_hf":
+        training_args.optim = "adamw_torch" # suppress warning
+
+    if (
+        training_args.resume_from_checkpoint is None
+        and training_args.do_train
+        and os.path.isdir(training_args.output_dir)
+        and not training_args.overwrite_output_dir
+    ):
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            raise ValueError("Output directory already exists and is not empty. Use `overwrite_output_dir`.")
+
+        if last_checkpoint is not None:
+            training_args.resume_from_checkpoint = last_checkpoint
+            logger.info(
+                "Resuming from checkpoint. Change `output_dir` or use `overwrite_output_dir` to avoid."
+            )
+
+    # postprocess model_args
+    if training_args.bf16:
+        if not torch.cuda.is_bf16_supported():
+            raise ValueError("Current device does not support bf16 training.")
+        model_args.compute_dtype = torch.bfloat16
+    else:
+        model_args.compute_dtype = torch.float16
+
+    model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
 
     # Log on each process the small summary:
-    logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, 16-bits training: {}".format(
+    logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
         training_args.local_rank, training_args.device, training_args.n_gpu,
-        bool(training_args.local_rank != -1), training_args.fp16
+        bool(training_args.local_rank != -1), str(model_args.compute_dtype)
     ))
     logger.info(f"Training/evaluation parameters {training_args}")
 
     # Set seed before initializing model.
     transformers.set_seed(training_args.seed)
 
-    return model_args, data_args, training_args, finetuning_args, general_args
+    return model_args, data_args, training_args, finetuning_args, generating_args, general_args
 
 def get_infer_args(
     args: Optional[Dict[str, Any]] = None
-) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
+) -> Tuple[
+    ModelArguments,
+    DataArguments,
+    FinetuningArguments,
+    GeneratingArguments
+]:
     model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
 
-    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
-        "Quantization is only compatible with the LoRA method."
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
 
     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
-            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-        else:
-            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
-                "Quantized model only accepts a single checkpoint."
+            if len(model_args.checkpoint_dir) != 1:
+                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+        elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Quantized model only accepts a single checkpoint.")
 
     return model_args, data_args, finetuning_args, generating_args
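
The auto-resume block above replaces the old manual workflow: when training into an existing non-empty output_dir, get_last_checkpoint either finds the newest checkpoint-* folder or the run aborts. A sketch with a hypothetical output directory:

from transformers.trainer_utils import get_last_checkpoint

# Suppose output/sft contains checkpoint-500/ and checkpoint-1000/:
last_checkpoint = get_last_checkpoint("output/sft")  # -> "output/sft/checkpoint-1000"
# training_args.resume_from_checkpoint is then set to this path; a non-empty
# directory without any checkpoint-* subfolder raises a ValueError instead.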

View File

@@ -13,26 +13,25 @@ from llmtuner.extras.logging import get_logger
 from llmtuner.extras.save_and_load import get_state_dict, load_trainable_params
 
 if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer, Seq2SeqTrainingArguments, TrainerState
     from llmtuner.hparams import FinetuningArguments
 
 logger = get_logger(__name__)
 
-class PeftTrainer(Seq2SeqTrainer):
+class PeftModelMixin:
     r"""
-    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
+    Patches the save and load methods in Hugging Face Trainer for PeftModel and ModelWithValueHead.
     """
 
-    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
-        super().__init__(**kwargs)
-        self.finetuning_args = finetuning_args
-        self._remove_log()
-
-    def _remove_log(self):
-        if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
-            logger.warning("Previous log file in this folder will be deleted.")
-            os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
+    def __init__(self) -> None: # for type checking
+        self.model: PreTrainedModel = None
+        self.tokenizer: "PreTrainedTokenizer" = None
+        self.args: "Seq2SeqTrainingArguments" = None
+        self.finetuning_args: "FinetuningArguments" = None
+        self.state: "TrainerState" = None
+        raise AssertionError("Mixin should not be initialized.")
 
     def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
         r"""
@@ -47,7 +46,6 @@ class PeftTrainer(Seq2SeqTrainer):
         logger.info(f"Saving model checkpoint to {output_dir}")
         model = unwrap_model(self.model)
-
         if isinstance(model, PreTrainedModelWrapper):
             # Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
             model_state_dict = state_dict or model.state_dict()
@@ -68,7 +66,10 @@ class PeftTrainer(Seq2SeqTrainer):
             torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
 
         if self.finetuning_args.finetuning_type == "full" and self.tokenizer is not None:
-            self.tokenizer.save_pretrained(output_dir)
+            try:
+                self.tokenizer.save_pretrained(output_dir)
+            except:
+                logger.warning("Cannot save tokenizer, copy the files manually.")
 
         with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
             f.write(self.args.to_json_string() + "\n")
@@ -94,3 +95,13 @@ class PeftTrainer(Seq2SeqTrainer):
             model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
         else: # freeze/full-tuning
             load_trainable_params(model, self.state.best_model_checkpoint)
+
+
+class PeftTrainer(PeftModelMixin, Seq2SeqTrainer):
+    r"""
+    Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
+    """
+
+    def __init__(self, finetuning_args: "FinetuningArguments", **kwargs):
+        Seq2SeqTrainer.__init__(self, **kwargs)
+        self.finetuning_args = finetuning_args
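
The point of the mixin split: PeftModelMixin carries only the patched save/load methods, so the same code can be reused by trainers that do not inherit from Seq2SeqTrainer (such as the DPOPeftTrainer below). Its guard __init__ is skipped on purpose by calling the base class initializer directly. A standalone sketch of the pattern, with hypothetical names:

class SaveMixin:
    def __init__(self):  # only here for type checkers
        raise AssertionError("Mixin should not be initialized.")

    def save(self):
        print("patched save")

class BaseTrainer:
    def __init__(self, name):
        self.name = name

class MyTrainer(SaveMixin, BaseTrainer):
    def __init__(self, name):
        BaseTrainer.__init__(self, name)  # bypasses the mixin guard

MyTrainer("demo").save()  # MRO resolves save() to SaveMixin -> prints "patched save"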

View File

@@ -0,0 +1 @@
+from llmtuner.tuner.dpo.workflow import run_dpo

View File

@@ -0,0 +1,51 @@
+import torch
+from dataclasses import dataclass
+from typing import Any, Dict, List, Sequence, Tuple
+from transformers import DataCollatorForSeq2Seq
+
+
+@dataclass
+class DPODataCollatorWithPadding(DataCollatorForSeq2Seq):
+    r"""
+    Data collator for pairwise data.
+    """
+
+    def _pad_labels(self, batch: torch.Tensor, positions: List[Tuple[int, int]]) -> torch.Tensor:
+        padded_labels = []
+        for feature, (prompt_len, answer_len) in zip(batch, positions):
+            if self.tokenizer.padding_side == "left":
+                start, end = feature.size(0) - answer_len, feature.size(0)
+            else:
+                start, end = prompt_len, prompt_len + answer_len
+            padded_tensor = self.label_pad_token_id * torch.ones_like(feature)
+            padded_tensor[start:end] = feature[start:end]
+            padded_labels.append(padded_tensor)
+        return torch.stack(padded_labels, dim=0).contiguous() # in contiguous memory
+
+    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+        r"""
+        Pads batched data to the longest sequence in the batch.
+
+        We generate 2 * n examples where the first n examples represent chosen examples and
+        the last n examples represent rejected examples.
+        """
+        concatenated_features = []
+        label_positions = []
+        for key in ("chosen_ids", "rejected_ids"):
+            for feature in features:
+                prompt_len, answer_len = len(feature["prompt_ids"]), len(feature[key])
+                concatenated_features.append({
+                    "input_ids": feature["prompt_ids"] + feature[key],
+                    "attention_mask": [1] * (prompt_len + answer_len)
+                })
+                label_positions.append((prompt_len, answer_len))
+
+        batch = self.tokenizer.pad(
+            concatenated_features,
+            padding=self.padding,
+            max_length=self.max_length,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            return_tensors=self.return_tensors,
+        )
+        batch["labels"] = self._pad_labels(batch["input_ids"], label_positions)
+        return batch
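
A small usage sketch of the collator (hypothetical token ids; gpt2 is only a stand-in tokenizer): with two features, the padded batch holds rows [chosen_0, chosen_1, rejected_0, rejected_1].

from transformers import AutoTokenizer
from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer
tokenizer.pad_token = tokenizer.eos_token
collator = DPODataCollatorWithPadding(tokenizer=tokenizer, label_pad_token_id=-100)

features = [
    {"prompt_ids": [1, 2, 3], "chosen_ids": [4, 5], "rejected_ids": [6]},
    {"prompt_ids": [7], "chosen_ids": [8, 9, 10], "rejected_ids": [11, 12]},
]
batch = collator(features)
print(batch["input_ids"].shape)  # torch.Size([4, 5]): 2 * n rows, padded to the longest pair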

View File

@@ -0,0 +1,77 @@
+import torch
+from collections import defaultdict
+from peft import PeftModel
+from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
+from transformers import BatchEncoding, Trainer
+from trl import DPOTrainer
+
+from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.tuner.core.trainer import PeftModelMixin
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel
+    from llmtuner.hparams import FinetuningArguments, GeneratingArguments
+
+
+class DPOPeftTrainer(PeftModelMixin, DPOTrainer):
+
+    def __init__(
+        self,
+        finetuning_args: "FinetuningArguments",
+        generating_args: "GeneratingArguments",
+        ref_model: Optional[Union["PreTrainedModel", torch.nn.Module]] = None,
+        **kwargs
+    ):
+        self.finetuning_args = finetuning_args
+        self.generating_args = generating_args
+        self.ref_model = ref_model
+        self.use_dpo_data_collator = True # hack to avoid warning
+        self.label_pad_token_id = IGNORE_INDEX
+        self.padding_value = 0
+        self.beta = finetuning_args.dpo_beta
+        self._stored_metrics = defaultdict(lambda: defaultdict(list))
+
+        Trainer.__init__(self, **kwargs)
+        if not hasattr(self, "accelerator"):
+            raise AttributeError("Please update `transformers`.")
+
+        if ref_model is not None:
+            self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
+
+    def concatenated_forward(
+        self,
+        model: Optional[torch.nn.Module] = None,
+        batch: Optional[Dict[str, torch.Tensor]] = None
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
+        batch_copied = BatchEncoding({k: v.detach().clone() for k, v in batch.items()}) # avoid error
+        unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
+
+        if not torch.is_grad_enabled():
+            unwrapped_model.gradient_checkpointing_disable()
+
+        if model is None and isinstance(unwrapped_model, PeftModel): # peft model has no ref_model
+            with unwrapped_model.disable_adapter():
+                all_logits = self.model(
+                    input_ids=batch_copied["input_ids"],
+                    attention_mask=batch_copied["attention_mask"],
+                    return_dict=True
+                ).logits.to(torch.float32)
+        else:
+            all_logits = model(
+                input_ids=batch_copied["input_ids"],
+                attention_mask=batch_copied["attention_mask"],
+                return_dict=True
+            ).logits.to(torch.float32)
+
+        if not torch.is_grad_enabled():
+            unwrapped_model.gradient_checkpointing_enable()
+
+        all_logps = self._get_batch_logps(
+            all_logits,
+            batch["labels"],
+            average_log_prob=False
+        )
+        batch_size = batch["input_ids"].size(0) // 2
+        chosen_logps, rejected_logps = all_logps.split(batch_size, dim=0)
+        chosen_logits, rejected_logits = all_logits.split(batch_size, dim=0)
+        return chosen_logps, rejected_logps, chosen_logits, rejected_logits
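
For reference, the log-probabilities returned here feed trl's standard DPO objective, with beta = finetuning_args.dpo_beta; a worked instance with toy numbers:

import torch
import torch.nn.functional as F

beta = 0.1  # the dpo_beta default above
policy_chosen, policy_rejected = torch.tensor(-10.0), torch.tensor(-14.0)
ref_chosen, ref_rejected = torch.tensor(-11.0), torch.tensor(-13.0)

# loss = -log sigmoid(beta * ((logp_c - ref_c) - (logp_r - ref_r)))
logits = (policy_chosen - ref_chosen) - (policy_rejected - ref_rejected)  # 2.0
loss = -F.logsigmoid(beta * logits)  # -log sigmoid(0.2) ~= 0.5981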

View File

@@ -0,0 +1,59 @@
+# Inspired by: https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
+
+from copy import deepcopy
+from peft import PeftModel
+from typing import TYPE_CHECKING, Optional, List
+
+from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
+from llmtuner.extras.constants import IGNORE_INDEX
+from llmtuner.extras.ploting import plot_loss
+from llmtuner.tuner.core import load_model_and_tokenizer
+from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding
+from llmtuner.tuner.dpo.trainer import DPOPeftTrainer
+
+if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
+
+
+def run_dpo(
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    generating_args: "GeneratingArguments",
+    callbacks: Optional[List["TrainerCallback"]] = None
+):
+    dataset = get_dataset(model_args, data_args)
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
+    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
+    data_collator = DPODataCollatorWithPadding(
+        tokenizer=tokenizer,
+        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
+    )
+
+    training_args.remove_unused_columns = False # important for pairwise dataset
+    ref_model = deepcopy(model) if not isinstance(model, PeftModel) else None
+
+    # Initialize our Trainer
+    trainer = DPOPeftTrainer(
+        finetuning_args=finetuning_args,
+        generating_args=generating_args,
+        ref_model=ref_model,
+        model=model,
+        args=training_args,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+        callbacks=callbacks,
+        **split_dataset(dataset, data_args, training_args)
+    )
+
+    # Training
+    if training_args.do_train:
+        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        trainer.log_metrics("train", train_result.metrics)
+        trainer.save_metrics("train", train_result.metrics)
+        trainer.save_state()
+        trainer.save_model()
+        if trainer.is_world_process_zero() and model_args.plot_loss:
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

View File

@@ -2,24 +2,23 @@ import os
 import math
 import torch
 from tqdm import tqdm
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
 from transformers import TrainerState, TrainerControl
-from transformers.modeling_utils import PreTrainedModel
 from trl import PPOTrainer
-from trl.core import LengthSampler
+from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
 
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
 from llmtuner.tuner.core.trainer import PeftTrainer
-from llmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
+from llmtuner.tuner.ppo.utils import replace_model
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments
+    from trl import AutoModelForCausalLMWithValueHead
     from llmtuner.extras.callbacks import LogCallback
-    from llmtuner.hparams import FinetuningArguments
+    from llmtuner.hparams import FinetuningArguments, GeneratingArguments
 
 logger = get_logger(__name__)
@@ -34,17 +33,19 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         self,
         training_args: "Seq2SeqTrainingArguments",
         finetuning_args: "FinetuningArguments",
+        generating_args: "GeneratingArguments",
         callbacks: List["LogCallback"],
+        compute_dtype: torch.dtype,
         **kwargs
     ):
         PPOTrainer.__init__(self, **kwargs)
         self.args = training_args
         self.finetuning_args = finetuning_args
+        self.generating_args = generating_args
         self.log_callback = callbacks[0]
+        self.compute_dtype = compute_dtype
         self.state = TrainerState()
         self.control = TrainerControl()
-        self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
-        self._remove_log()
 
     def ppo_train(self, max_target_length: int) -> None:
         r"""
@@ -74,16 +75,13 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         logger.info(f"  Number of trainable parameters = {count_parameters(self.model)[0]}")
 
         # Keyword arguments for `model.generate`
-        gen_kwargs = {
-            "top_k": 0.0,
-            "top_p": 1.0,
-            "do_sample": True,
-            "pad_token_id": self.tokenizer.pad_token_id,
-            "eos_token_id": self.tokenizer.eos_token_id,
-            "logits_processor": get_logits_processor()
-        }
+        gen_kwargs = self.generating_args.to_dict()
+        gen_kwargs["eos_token_id"] = list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids))
+        gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
+        gen_kwargs["logits_processor"] = get_logits_processor()
+
         length_sampler = LengthSampler(max_target_length // 2, max_target_length)
-        unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)
+        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
 
         dataiter = iter(self.dataloader)
         steps_trained = 0
@@ -91,51 +89,38 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         reward_meter = AverageMeter()
         self.log_callback.on_train_begin(self.args, self.state, self.control)
 
-        for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False):
+        for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
             batch = next(dataiter)
             steps_trained += 1
 
-            # Cast to inference mode
             unwrapped_model.gradient_checkpointing_disable()
             unwrapped_model.config.use_cache = True
 
-            # Get responses
-            query_tensors = batch["input_ids"]
-            response_tensors = self.generate(batch, length_sampler, return_prompt=False, **gen_kwargs)
-
-            queries, responses = [], []
-            for i in range(len(query_tensors)):
-                query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
-                response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
-                queries.append(query_tensors[i, query_length:]) # remove padding from left
-                responses.append(response_tensors[i, :response_length]) # remove padding from right
-
-            # Compute rewards
-            replace_model(unwrapped_model, target="reward")
-            with torch.no_grad():
-                _, _, values = self.model(
-                    **self.prepare_model_inputs(queries, responses),
-                    output_hidden_states=True,
-                    return_dict=True
-                )
-            rewards = [reward for reward in values[:, -1].to(torch.float32)] # use float32 type
-            replace_model(unwrapped_model, target="default")
+            # Get inputs
+            queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
+            rewards = self.get_rewards(queries, responses, unwrapped_model)
 
-            # Run PPO step
+            # Cast to training mode
             unwrapped_model.gradient_checkpointing_enable()
             unwrapped_model.config.use_cache = False
-            stats = self.step(queries, responses, rewards)
 
+            # Run PPO step
+            stats = self.step(queries, responses, rewards)
+
             loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
             reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
 
-            if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
+            self.state.global_step += 1
+            self.log_callback.on_step_end(self.args, self.state, self.control)
+
+            if self.is_local_process_zero() and (step+1) % self.args.logging_steps == 0:
                 logs = dict(
                     loss=round(loss_meter.avg, 4),
                     reward=round(reward_meter.avg, 4),
                     learning_rate=stats["ppo/learning_rate"],
                     epoch=round(step / len_dataloader, 2)
                 )
-                print(logs)
+                tqdm.write(str(logs))
                 logs["step"] = step
                 self.state.log_history.append(logs)
                 self.log_callback.on_log(self.args, self.state, self.control)
@@ -152,38 +137,124 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
                 dataiter = iter(self.dataloader)
                 steps_trained = 0
 
+        self.log_callback.on_train_end(self.args, self.state, self.control)
+
     @torch.no_grad()
-    def generate(
+    def get_inputs(
         self,
-        inputs: Dict[str, torch.Tensor],
+        batch: Dict[str, torch.Tensor],
         length_sampler: Optional[Callable] = None,
-        return_prompt: Optional[bool] = True,
         **generation_kwargs
-    ) -> torch.Tensor:
+    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
         r"""
         Generates model's responses given queries.
-
-        Subclass and override to inject custom behavior.
         """
-        self.model, layer_norm_params = cast_layernorm_dtype(self.model)
-
         if length_sampler is not None:
             generation_kwargs["max_new_tokens"] = length_sampler()
 
-        unwrapped_model = self.accelerator.unwrap_model(self.model)
+        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
 
-        response = unwrapped_model.generate(**inputs, **generation_kwargs)
+        response: torch.Tensor = unwrapped_model.generate(**batch, **generation_kwargs)
 
         # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
         # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
         if unwrapped_model.pretrained_model.generation_config._from_model_config:
             unwrapped_model.pretrained_model.generation_config._from_model_config = False
 
-        self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
-
-        if not return_prompt and not self.is_encoder_decoder:
-            return response[:, inputs["input_ids"].size(1):]
-
-        return response
+        queries, responses = [], []
+        query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
+        for i in range(len(query)):
+            query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
+            response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
+            queries.append(query[i, query_length:]) # remove padding from left
+            responses.append(response[i, :response_length]) # remove padding from right
+
+        return queries, responses
+
+    @torch.no_grad()
+    def get_rewards(
+        self,
+        queries: List[torch.Tensor],
+        responses: List[torch.Tensor],
+        unwrapped_model: "AutoModelForCausalLMWithValueHead"
+    ) -> List[torch.Tensor]:
+        r"""
+        Computes scores using given reward model.
+        """
+        replace_model(unwrapped_model, target="reward")
+        batch = self.prepare_model_inputs(queries, responses)
+
+        with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
+            _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
+
+        if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2
+            values = torch.transpose(values, 0, 1)
+
+        rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
+        replace_model(unwrapped_model, target="default")
+        return rewards
+
+    @PPODecorators.empty_cuda_cache()
+    def batched_forward_pass(
+        self,
+        model: "AutoModelForCausalLMWithValueHead",
+        queries: torch.Tensor,
+        responses: torch.Tensor,
+        model_inputs: dict,
+        return_logits: Optional[bool] = False
+    ):
+        r"""
+        Calculates model outputs in multiple batches.
+
+        Subclass and override to inject custom behavior.
+        """
+        bs = len(queries)
+        fbs = self.config.mini_batch_size
+        all_logprobs = []
+        all_logits = []
+        all_masks = []
+        all_values = []
+
+        for i in range(math.ceil(bs / fbs)):
+            input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
+            query_batch = queries[i * fbs : (i + 1) * fbs]
+            response_batch = responses[i * fbs : (i + 1) * fbs]
+            input_ids = input_kwargs["input_ids"]
+            attention_mask = input_kwargs["attention_mask"]
+
+            with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
+                logits, _, values = model(**input_kwargs)
+
+            if values.size(0) != input_ids.size(0): # adapt to chatglm2
+                values = torch.transpose(values, 0, 1)
+
+            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
+            masks = torch.zeros_like(attention_mask)
+            masks[:, :-1] = attention_mask[:, 1:]
+
+            for j in range(len(query_batch)):
+                start = len(query_batch[j]) - 1
+                if attention_mask[j, 0] == 0: # offset left padding
+                    start += attention_mask[j, :].nonzero()[0]
+                end = start + len(response_batch[j])
+                masks[j, :start] = 0
+                masks[j, end:] = 0
+
+            if return_logits:
+                all_logits.append(logits)
+            else:
+                del logits
+
+            all_values.append(values)
+            all_logprobs.append(logprobs)
+            all_masks.append(masks)
+
+        return (
+            torch.cat(all_logprobs),
+            torch.cat(all_logits)[:, :-1] if return_logits else None,
+            torch.cat(all_values)[:, :-1],
+            torch.cat(all_masks)[:, :-1],
        )
 
     def save_model(self, output_dir: Optional[str] = None) -> None:
         r"""

View File

@@ -1,7 +1,4 @@
-import torch
-from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Tuple
-
-from llmtuner.extras.constants import LAYERNORM_NAMES
+from typing import TYPE_CHECKING, Literal
 
 if TYPE_CHECKING:
     from trl import AutoModelForCausalLMWithValueHead
@@ -18,22 +15,3 @@ def replace_model(model: "AutoModelForCausalLMWithValueHead", target: Literal["d
         "summary.weight": getattr(model, "{}_head_weight".format(target)),
         "summary.bias": getattr(model, "{}_head_bias".format(target))
     })
-
-
-def cast_layernorm_dtype(
-    model: "AutoModelForCausalLMWithValueHead",
-    layer_norm_names: List[str] = LAYERNORM_NAMES,
-    layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
-) -> Tuple["AutoModelForCausalLMWithValueHead", Dict[str, torch.Tensor]]:
-
-    layer_norm_state_dict = {}
-
-    for name, param in model.named_parameters():
-        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
-            if layer_norm_params is not None:
-                param.data = layer_norm_params[name] # restore float32 weights
-            else:
-                layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
-                param.data = param.data.to(torch.float16)
-
-    return model, layer_norm_state_dict
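replace_model lets a single AutoModelForCausalLMWithValueHead act as both policy and reward model by swapping the value-head weights in place. A runnable toy version of the same pattern, with a plain nn.Linear standing in for the value head (all names here are illustrative):

    import torch
    import torch.nn as nn

    class ToyValueHeadModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.v_head = nn.Linear(4, 1)
            # two stashed weight sets, swapped in and out like replace_model does
            for target in ("default", "reward"):
                self.register_buffer(f"{target}_head_weight", torch.randn(1, 4))
                self.register_buffer(f"{target}_head_bias", torch.randn(1))

        def swap(self, target: str) -> None:
            self.v_head.load_state_dict({
                "weight": getattr(self, f"{target}_head_weight"),
                "bias": getattr(self, f"{target}_head_bias"),
            })

    model = ToyValueHeadModel()
    model.swap("reward")   # score rollouts with the reward head
    model.swap("default")  # restore the value head before the PPO step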

View File

@@ -1,23 +1,21 @@
-# Inspired by:
-# https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
+# Inspired by: https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py
 
 import math
-from typing import TYPE_CHECKING
 from trl import PPOConfig
 from torch.optim import AdamW
-from typing import Optional, List
+from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForSeq2Seq
 from transformers.optimization import get_scheduler
+from transformers.utils.versions import require_version
 
 from llmtuner.dsets import get_dataset, preprocess_dataset
-from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.ppo.trainer import PPOPeftTrainer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
 
 def run_ppo(
@@ -25,7 +23,8 @@ def run_ppo(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
+    generating_args: "GeneratingArguments",
+    callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
@@ -39,24 +38,35 @@ def run_ppo(
         batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
         gradient_accumulation_steps=training_args.gradient_accumulation_steps,
         ppo_epochs=1,
-        max_grad_norm=training_args.max_grad_norm
+        max_grad_norm=training_args.max_grad_norm,
+        seed=training_args.seed,
+        optimize_cuda_cache=True
     )
 
-    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
-    total_train_batch_size = \
+    if finetuning_args.ppo_score_norm:
+        require_version("trl>=0.5.1.dev0", "To fix: pip install git+https://github.com/huggingface/trl.git")
+        ppo_config.use_score_scaling = True
+        ppo_config.use_score_norm = True
+
+    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
+    total_train_batch_size = (
         training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
+    )
+    num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
     lr_scheduler = get_scheduler(
         training_args.lr_scheduler_type,
         optimizer=optimizer,
-        num_warmup_steps=training_args.warmup_steps,
-        num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
+        num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
+        num_training_steps=num_training_steps
     )
 
     # Initialize our Trainer
     ppo_trainer = PPOPeftTrainer(
         training_args=training_args,
         finetuning_args=finetuning_args,
+        generating_args=generating_args,
         callbacks=callbacks,
+        compute_dtype=model_args.compute_dtype,
         config=ppo_config,
         model=model,
         ref_model=None,
@@ -67,8 +77,10 @@ def run_ppo(
         lr_scheduler=lr_scheduler
     )
 
-    ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
-    ppo_trainer.save_model()
-    ppo_trainer.save_state() # must be after save_model
-    if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
-        plot_loss(training_args.output_dir, keys=["loss", "reward"])
+    # Training
+    if training_args.do_train:
+        ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
+        ppo_trainer.save_model()
+        ppo_trainer.save_state() # must be called after save_model to have a folder
+        if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "reward"])
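The scheduler arithmetic above is worth spelling out, since warmup now derives from the computed step count. A worked example with illustrative numbers:

    import math

    per_device_bs, grad_accum, world_size = 4, 4, 2
    num_epochs, dataset_len = 3.0, 10_000

    total_train_batch_size = per_device_bs * grad_accum * world_size  # 32
    num_training_steps = num_epochs * math.ceil(dataset_len / total_train_batch_size)
    print(int(num_training_steps))  # 939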

View File

@@ -2,11 +2,9 @@
 import math
 from typing import TYPE_CHECKING, Optional, List
-from transformers import DataCollatorForSeq2Seq
+from transformers import DataCollatorForLanguageModeling
 
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.callbacks import LogCallback
-from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.core.trainer import PeftTrainer
@@ -21,15 +19,12 @@ def run_pt(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
+    callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="pt")
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="pt")
-    data_collator = DataCollatorForSeq2Seq(
-        tokenizer=tokenizer,
-        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
-    )
+    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
 
     # Initialize our Trainer
     trainer = PeftTrainer(
@@ -39,12 +34,12 @@ def run_pt(
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
-        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
+        **split_dataset(dataset, data_args, training_args)
     )
 
     # Training
     if training_args.do_train:
-        train_result = trainer.train()
+        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
@@ -61,6 +56,5 @@ def run_pt(
             perplexity = float("inf")
 
         metrics["perplexity"] = perplexity
         trainer.log_metrics("eval", metrics)
         trainer.save_metrics("eval", metrics)
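Switching to DataCollatorForLanguageModeling(mlm=False) lets the collator derive causal-LM labels directly from input_ids, so the IGNORE_INDEX plumbing is no longer needed here. The eval metric is plain perplexity, guarded against overflow the same way the workflow does:

    import math

    eval_loss = 2.31  # illustrative mean eval loss
    try:
        perplexity = math.exp(eval_loss)
    except OverflowError:
        perplexity = float("inf")
    print(round(perplexity, 2))  # 10.07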

View File

@@ -1,8 +1,10 @@
 import torch
+from dataclasses import dataclass
 from typing import Any, Dict, Sequence
 from transformers import DataCollatorWithPadding
 
+@dataclass
 class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
     r"""
     Data collator for pairwise data.
@@ -16,7 +18,10 @@ class PairwiseDataCollatorWithPadding(DataCollatorWithPadding):
         the last n examples represent rejected examples.
         """
         features = [
-            {"input_ids": feature[key], "attention_mask": [1] * len(feature[key])}
-            for key in ("accept_ids", "reject_ids") for feature in features
+            {
+                "input_ids": feature["prompt_ids"] + feature[key],
+                "attention_mask": [1] * (len(feature["prompt_ids"]) + len(feature[key]))
+            }
+            for key in ("chosen_ids", "rejected_ids") for feature in features
         ]
         return super().__call__(features)
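Concretely, the collator now prepends the shared prompt to each candidate, turning n pairwise examples into 2n rows with chosen completions first. A toy run of the same comprehension:

    features = [
        {"prompt_ids": [1, 2], "chosen_ids": [3], "rejected_ids": [4, 5]},
    ]
    rows = [
        {
            "input_ids": f["prompt_ids"] + f[key],
            "attention_mask": [1] * (len(f["prompt_ids"]) + len(f[key])),
        }
        for key in ("chosen_ids", "rejected_ids") for f in features
    ]
    print(rows[0]["input_ids"], rows[1]["input_ids"])  # [1, 2, 3] [1, 2, 4, 5]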

View File

@@ -42,6 +42,8 @@ class PairwisePeftTrainer(PeftTrainer):
         """
         batch_size = inputs["input_ids"].size(0) // 2
         _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
+        if values.size(0) != inputs["input_ids"].size(0): # adapt to chatglm2
+            values = torch.transpose(values, 0, 1)
         r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
         loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
         return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
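Since the batch stacks chosen rows above rejected rows, splitting the last-position values in half yields the two reward vectors, and the loss is the negative log-sigmoid of their margin. A quick numeric check:

    import torch

    r_accept = torch.tensor([1.5, 0.2])
    r_reject = torch.tensor([0.5, 0.4])
    loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
    print(loss.item())  # ≈ 0.556; shrinks as chosen rewards pull ahead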

View File

@@ -5,7 +5,6 @@
 from typing import TYPE_CHECKING, Optional, List
 
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.ploting import plot_loss
 from llmtuner.tuner.core import load_model_and_tokenizer
 from llmtuner.tuner.rm.metric import compute_accuracy
@@ -22,7 +21,7 @@ def run_rm(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
+    callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
@@ -40,7 +39,7 @@ def run_rm(
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=compute_accuracy,
-        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
+        **split_dataset(dataset, data_args, training_args)
     )
 
     # Training
# Training # Training

View File

@@ -25,7 +25,7 @@ class ComputeMetrics:
         Uses the model predictions to compute metrics.
         """
         preds, labels = eval_preds
-        score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
+        score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
 
         preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
         labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
@@ -49,6 +49,5 @@ class ComputeMetrics:
             bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
             score_dict["bleu-4"].append(round(bleu_score * 100, 4))
-            score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label))
 
         return {k: float(np.mean(v)) for k, v in score_dict.items()}
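The remaining metrics are character-level: ROUGE over the decoded strings and smoothed BLEU-4, exactly as called above. A standalone example of the BLEU call (toy strings):

    from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

    pred, label = "今天天气不错", "今天天气很好"
    score = sentence_bleu([list(label)], list(pred),
                          smoothing_function=SmoothingFunction().method3)
    print(round(score * 100, 4))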

View File

@@ -50,11 +50,12 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         loss, generated_tokens, labels = super().prediction_step(
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
-        generated_tokens = (
-            generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
-        )
+        if generated_tokens is not None:
+            generated_tokens[:, :max(prompt_len, label_len)] = (
+                self.tokenizer.pad_token_id * torch.ones_like(generated_tokens[:, :max(prompt_len, label_len)])
+            )
 
-        return (loss, generated_tokens, labels)
+        return loss, generated_tokens, labels
 
     def _pad_tensors_to_target_len(
         self,
@@ -72,14 +73,11 @@ class Seq2SeqPeftTrainer(PeftTrainer):
             assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
             pad_token_id = self.tokenizer.pad_token_id
         else:
-            if self.model.config.pad_token_id is not None:
-                pad_token_id = self.model.config.pad_token_id
-            else:
-                raise ValueError("Pad_token_id must be set in the configuration of the model.")
+            raise ValueError("PAD token is required.")
 
         padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
         padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding
-        return padded_tensor
+        return padded_tensor.contiguous() # in contiguous memory
 
     def save_predictions(
         self,
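_pad_tensors_to_target_len relies on left padding: the source tensor is written into the right edge of a pad-filled tensor shaped like the target. A toy check:

    import torch

    pad_token_id = 0
    src_tensor = torch.tensor([[5, 6, 7]])
    tgt_tensor = torch.zeros(1, 5, dtype=torch.long)

    padded = pad_token_id * torch.ones_like(tgt_tensor)
    padded[:, -src_tensor.shape[-1]:] = src_tensor  # adopt left-padding
    print(padded.contiguous())  # tensor([[0, 0, 5, 6, 7]])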

View File

@@ -4,7 +4,6 @@ from typing import TYPE_CHECKING, Optional, List
 from transformers import DataCollatorForSeq2Seq
 
 from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.extras.misc import get_logits_processor
 from llmtuner.extras.ploting import plot_loss
@@ -14,7 +13,7 @@ from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer
 
 if TYPE_CHECKING:
     from transformers import Seq2SeqTrainingArguments, TrainerCallback
-    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
+    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
 
 def run_sft(
@@ -22,7 +21,8 @@ def run_sft(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     finetuning_args: "FinetuningArguments",
-    callbacks: Optional[List["TrainerCallback"]] = [LogCallback()]
+    generating_args: "GeneratingArguments",
+    callbacks: Optional[List["TrainerCallback"]] = None
 ):
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
@@ -47,21 +47,18 @@ def run_sft(
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
-        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
+        **split_dataset(dataset, data_args, training_args)
     )
 
     # Keyword arguments for `model.generate`
-    gen_kwargs = {
-        "do_sample": True,
-        "top_p": 0.7,
-        "max_new_tokens": data_args.max_target_length + 1,
-        "temperature": 0.95,
-        "logits_processor": get_logits_processor()
-    }
+    gen_kwargs = generating_args.to_dict()
+    gen_kwargs["eos_token_id"] = list(set([tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids))
+    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
+    gen_kwargs["logits_processor"] = get_logits_processor()
 
     # Training
     if training_args.do_train:
-        train_result = trainer.train()
+        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()

View File

@@ -0,0 +1,48 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from llmtuner.extras.callbacks import LogCallback
+from llmtuner.extras.logging import get_logger
+from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer
+from llmtuner.tuner.pt import run_pt
+from llmtuner.tuner.sft import run_sft
+from llmtuner.tuner.rm import run_rm
+from llmtuner.tuner.ppo import run_ppo
+from llmtuner.tuner.dpo import run_dpo
+
+if TYPE_CHECKING:
+    from transformers import TrainerCallback
+
+logger = get_logger(__name__)
+
+def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["TrainerCallback"]] = None):
+    model_args, data_args, training_args, finetuning_args, generating_args, general_args = get_train_args(args)
+    callbacks = [LogCallback()] if callbacks is None else callbacks
+
+    if general_args.stage == "pt":
+        run_pt(model_args, data_args, training_args, finetuning_args, callbacks)
+    elif general_args.stage == "sft":
+        run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+    elif general_args.stage == "rm":
+        run_rm(model_args, data_args, training_args, finetuning_args, callbacks)
+    elif general_args.stage == "ppo":
+        run_ppo(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
+    elif general_args.stage == "dpo":
+        run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
+    else:
+        raise ValueError("Unknown task.")
+
+def export_model(args: Optional[Dict[str, Any]] = None, max_shard_size: Optional[str] = "10GB"):
+    model_args, _, training_args, finetuning_args, _, _ = get_train_args(args)
+    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
+    model.save_pretrained(training_args.output_dir, max_shard_size=max_shard_size)
+    try:
+        tokenizer.save_pretrained(training_args.output_dir)
+    except:
+        logger.warning("Cannot save tokenizer, please copy the files manually.")
+
+if __name__ == "__main__":
+    run_exp()
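run_exp also works programmatically by passing an argument dict instead of reading sys.argv. A hypothetical invocation; the exact argument names depend on the hparams dataclasses and are assumptions here:

    from llmtuner.tuner.tune import run_exp  # assuming this module path

    run_exp(args={
        "stage": "sft",                        # assumed key names
        "model_name_or_path": "path/to/model",
        "do_train": True,
        "dataset": "alpaca_gpt4_en",
        "output_dir": "path/to/output",
    })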

View File

@@ -0,0 +1 @@
+from llmtuner.webui.interface import create_ui, create_web_demo

View File

@@ -1,22 +1,22 @@
 import os
-from typing import List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 from llmtuner.chat.stream_chat import ChatModel
 from llmtuner.extras.misc import torch_gc
 from llmtuner.hparams import GeneratingArguments
-from llmtuner.tuner import get_infer_args
 from llmtuner.webui.common import get_model_path, get_save_dir
 from llmtuner.webui.locales import ALERTS
 
 class WebChatModel(ChatModel):
 
-    def __init__(self, *args):
-        self.model = None
-        self.tokenizer = None
-        self.generating_args = GeneratingArguments()
-        if len(args) != 0:
-            super().__init__(*args)
+    def __init__(self, args: Optional[Dict[str, Any]] = None, lazy_init: Optional[bool] = True) -> None:
+        if lazy_init:
+            self.model = None
+            self.tokenizer = None
+            self.generating_args = GeneratingArguments()
+        else:
+            super().__init__(args)
 
     def load_model(
         self,
@@ -26,7 +26,7 @@ class WebChatModel(ChatModel):
         finetuning_type: str,
         quantization_bit: str,
         template: str,
-        source_prefix: str
+        system_prompt: str
     ):
         if self.model is not None:
             yield ALERTS["err_exists"][lang]
@@ -53,11 +53,11 @@ class WebChatModel(ChatModel):
             model_name_or_path=model_name_or_path,
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
             template=template,
-            source_prefix=source_prefix
+            system_prompt=system_prompt
         )
-        super().__init__(*get_infer_args(args))
+        super().__init__(args)
 
         yield ALERTS["info_loaded"][lang]
@@ -73,7 +73,7 @@ class WebChatModel(ChatModel):
         chatbot: List[Tuple[str, str]],
         query: str,
         history: List[Tuple[str, str]],
-        prefix: str,
+        system: str,
         max_new_tokens: int,
         top_p: float,
         temperature: float
@@ -81,7 +81,7 @@ class WebChatModel(ChatModel):
         chatbot.append([query, ""])
         response = ""
         for new_text in self.stream_chat(
-            query, history, prefix, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
+            query, history, system, max_new_tokens=max_new_tokens, top_p=top_p, temperature=temperature
         ):
             response += new_text
             response = self.postprocess(response)

View File

@@ -6,7 +6,7 @@ import gradio as gr
 from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
 from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
 
-from llmtuner.extras.constants import SUPPORTED_MODELS
+from llmtuner.extras.constants import DEFAULT_TEMPLATE, SUPPORTED_MODELS
 
 DEFAULT_CACHE_DIR = "cache"
@@ -29,14 +29,16 @@ def load_config() -> Dict[str, Any]:
         with open(get_config_path(), "r", encoding="utf-8") as f:
             return json.load(f)
     except:
-        return {"last_model": "", "path_dict": {}}
+        return {"lang": "", "last_model": "", "path_dict": {}}
 
-def save_config(model_name: str, model_path: str) -> None:
+def save_config(lang: str, model_name: str, model_path: str) -> None:
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
     user_config = load_config()
-    user_config["last_model"] = model_name
-    user_config["path_dict"][model_name] = model_path
+    user_config["lang"] = lang or user_config["lang"]
+    if model_name:
+        user_config["last_model"] = model_name
+        user_config["path_dict"][model_name] = model_path
     with open(get_config_path(), "w", encoding="utf-8") as f:
         json.dump(user_config, f, indent=2, ensure_ascii=False)
@@ -46,6 +48,12 @@ def get_model_path(model_name: str) -> str:
     return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
 
+def get_template(model_name: str) -> str:
+    if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
+        return DEFAULT_TEMPLATE[model_name.split("-")[0]]
+    return "default"
+
 def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
     checkpoints = []
     save_dir = os.path.join(get_save_dir(model_name), finetuning_type)
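get_template gives chat models a matching prompt template automatically, while everything else falls back to "default". A self-contained demo of the behavior (the DEFAULT_TEMPLATE contents here are illustrative, not the real constant):

    DEFAULT_TEMPLATE = {"Baichuan": "baichuan", "InternLM": "intern"}  # assumed mapping

    def get_template(model_name: str) -> str:
        if model_name.endswith("Chat") and model_name.split("-")[0] in DEFAULT_TEMPLATE:
            return DEFAULT_TEMPLATE[model_name.split("-")[0]]
        return "default"

    print(get_template("Baichuan-13B-Chat"))  # baichuan
    print(get_template("LLaMA-7B"))           # default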

View File

@@ -1,5 +1,6 @@
 from llmtuner.webui.components.top import create_top
-from llmtuner.webui.components.sft import create_sft_tab
+from llmtuner.webui.components.train import create_train_tab
 from llmtuner.webui.components.eval import create_eval_tab
 from llmtuner.webui.components.infer import create_infer_tab
 from llmtuner.webui.components.export import create_export_tab
+from llmtuner.webui.components.chatbot import create_chat_box

View File

@@ -17,7 +17,7 @@ def create_chat_box(
     with gr.Row():
         with gr.Column(scale=4):
-            prefix = gr.Textbox(show_label=False)
+            system = gr.Textbox(show_label=False)
             query = gr.Textbox(show_label=False, lines=8)
             submit_btn = gr.Button(variant="primary")
@@ -31,7 +31,7 @@ def create_chat_box(
     submit_btn.click(
         chat_model.predict,
-        [chatbot, query, history, prefix, max_new_tokens, top_p, temperature],
+        [chatbot, query, history, system, max_new_tokens, top_p, temperature],
         [chatbot, history],
         show_progress=True
     ).then(
@@ -41,7 +41,7 @@ def create_chat_box(
     clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
 
     return chat_box, chatbot, history, dict(
-        prefix=prefix,
+        system=system,
         query=query,
         submit_btn=submit_btn,
         clear_btn=clear_btn,

View File

@@ -16,6 +16,6 @@ def create_preview_box() -> Tuple["Block", "Component", "Component", "Component"
         close_btn = gr.Button()
 
-    close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box])
+    close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box], queue=False)
 
     return preview_box, preview_count, preview_samples, close_btn

View File

@@ -14,13 +14,18 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
     with gr.Row():
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
         dataset = gr.Dropdown(multiselect=True, scale=4)
-        preview_btn = gr.Button(interactive=False, scale=1)
+        data_preview_btn = gr.Button(interactive=False, scale=1)
 
     preview_box, preview_count, preview_samples, close_btn = create_preview_box()
 
     dataset_dir.change(list_dataset, [dataset_dir], [dataset])
-    dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
-    preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
+    dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
+    data_preview_btn.click(
+        get_preview,
+        [dataset_dir, dataset],
+        [preview_count, preview_samples, preview_box],
+        queue=False
+    )
 
     with gr.Row():
         max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
@@ -30,38 +35,46 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
         predict = gr.Checkbox(value=True)
 
     with gr.Row():
+        cmd_preview_btn = gr.Button()
         start_btn = gr.Button()
         stop_btn = gr.Button()
 
+    with gr.Row():
+        process_bar = gr.Slider(visible=False, interactive=False)
+
     with gr.Box():
         output_box = gr.Markdown()
 
-    start_btn.click(
-        runner.run_eval,
-        [
-            top_elems["lang"],
-            top_elems["model_name"],
-            top_elems["checkpoints"],
-            top_elems["finetuning_type"],
-            top_elems["quantization_bit"],
-            top_elems["template"],
-            top_elems["source_prefix"],
-            dataset_dir,
-            dataset,
-            max_source_length,
-            max_target_length,
-            max_samples,
-            batch_size,
-            predict
-        ],
-        [output_box]
-    )
+    input_components = [
+        top_elems["lang"],
+        top_elems["model_name"],
+        top_elems["checkpoints"],
+        top_elems["finetuning_type"],
+        top_elems["quantization_bit"],
+        top_elems["template"],
+        top_elems["system_prompt"],
+        dataset_dir,
+        dataset,
+        max_source_length,
+        max_target_length,
+        max_samples,
+        batch_size,
+        predict
+    ]
+
+    output_components = [
+        output_box,
+        process_bar
+    ]
+
+    cmd_preview_btn.click(runner.preview_eval, input_components, output_components)
+    start_btn.click(runner.run_eval, input_components, output_components)
     stop_btn.click(runner.set_abort, queue=False)
 
     return dict(
         dataset_dir=dataset_dir,
         dataset=dataset,
-        preview_btn=preview_btn,
+        data_preview_btn=data_preview_btn,
         preview_count=preview_count,
         preview_samples=preview_samples,
         close_btn=close_btn,
@@ -70,6 +83,7 @@ def create_eval_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict
         max_samples=max_samples,
         batch_size=batch_size,
         predict=predict,
+        cmd_preview_btn=cmd_preview_btn,
         start_btn=start_btn,
         stop_btn=stop_btn,
         output_box=output_box

View File

@@ -1,7 +1,7 @@
 from typing import TYPE_CHECKING, Dict
 import gradio as gr
 
-from llmtuner.webui.utils import export_model
+from llmtuner.webui.utils import save_model
 
 if TYPE_CHECKING:
     from gradio.components import Component
@@ -16,12 +16,13 @@ def create_export_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component
     info_box = gr.Textbox(show_label=False, interactive=False)
 
     export_btn.click(
-        export_model,
+        save_model,
         [
             top_elems["lang"],
             top_elems["model_name"],
             top_elems["checkpoints"],
             top_elems["finetuning_type"],
+            top_elems["template"],
             max_shard_size,
             save_dir
         ],

View File

@@ -28,7 +28,7 @@ def create_infer_tab(top_elems: Dict[str, "Component"]) -> Dict[str, "Component"
             top_elems["finetuning_type"],
             top_elems["quantization_bit"],
             top_elems["template"],
-            top_elems["source_prefix"]
+            top_elems["system_prompt"]
         ],
         [info_box]
     ).then(

View File

@@ -4,7 +4,7 @@ import gradio as gr
 
 from llmtuner.extras.constants import METHODS, SUPPORTED_MODELS
 from llmtuner.extras.template import templates
-from llmtuner.webui.common import list_checkpoint, get_model_path, save_config
+from llmtuner.webui.common import list_checkpoint, get_model_path, get_template, save_config
 from llmtuner.webui.utils import can_quantize
 
 if TYPE_CHECKING:
@@ -15,27 +15,32 @@ def create_top() -> Dict[str, "Component"]:
     available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
 
     with gr.Row():
-        lang = gr.Dropdown(choices=["en", "zh"], value="en", scale=1)
+        lang = gr.Dropdown(choices=["en", "zh"], scale=1)
         model_name = gr.Dropdown(choices=available_models, scale=3)
         model_path = gr.Textbox(scale=3)
 
     with gr.Row():
-        finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
+        finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
         checkpoints = gr.Dropdown(multiselect=True, scale=5)
         refresh_btn = gr.Button(scale=1)
 
     with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
         with gr.Row():
-            quantization_bit = gr.Dropdown([8, 4], scale=1)
-            template = gr.Dropdown(value="default", choices=list(templates.keys()), scale=1)
-            source_prefix = gr.Textbox(scale=2)
+            quantization_bit = gr.Dropdown(choices=["None", "8", "4"], value="None", scale=1)
+            template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=1)
+            system_prompt = gr.Textbox(scale=2)
 
+    lang.change(save_config, [lang, model_name, model_path])
+
     model_name.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
     ).then(
         get_model_path, [model_name], [model_path]
+    ).then(
+        get_template, [model_name], [template]
     ) # do not save config since the below line will save
 
-    model_path.change(save_config, [model_name, model_path])
+    model_path.change(save_config, [lang, model_name, model_path])
 
     finetuning_type.change(
         list_checkpoint, [model_name, finetuning_type], [checkpoints]
@@ -43,7 +48,9 @@ def create_top() -> Dict[str, "Component"]:
         can_quantize, [finetuning_type], [quantization_bit]
     )
 
-    refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
+    refresh_btn.click(
+        list_checkpoint, [model_name, finetuning_type], [checkpoints], queue=False
+    )
 
     return dict(
         lang=lang,
@@ -55,5 +62,5 @@ def create_top() -> Dict[str, "Component"]:
         advanced_tab=advanced_tab,
         quantization_bit=quantization_bit,
         template=template,
-        source_prefix=source_prefix
+        system_prompt=system_prompt
     )

View File

@@ -3,7 +3,8 @@ from transformers.trainer_utils import SchedulerType
 import gradio as gr
 
-from llmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
+from llmtuner.extras.constants import STAGES
+from llmtuner.webui.common import list_checkpoint, list_dataset, DEFAULT_DATA_DIR
 from llmtuner.webui.components.data import create_preview_box
 from llmtuner.webui.utils import can_preview, get_preview, gen_plot
@@ -12,17 +13,23 @@ if TYPE_CHECKING:
     from llmtuner.webui.runner import Runner
 
-def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
+def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[str, "Component"]:
     with gr.Row():
+        training_stage = gr.Dropdown(choices=STAGES, value=STAGES[0], scale=2)
         dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
         dataset = gr.Dropdown(multiselect=True, scale=4)
-        preview_btn = gr.Button(interactive=False, scale=1)
+        data_preview_btn = gr.Button(interactive=False, scale=1)
 
     preview_box, preview_count, preview_samples, close_btn = create_preview_box()
 
     dataset_dir.change(list_dataset, [dataset_dir], [dataset])
-    dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
-    preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
+    dataset.change(can_preview, [dataset_dir, dataset], [data_preview_btn])
+    data_preview_btn.click(
+        get_preview,
+        [dataset_dir, dataset],
+        [preview_count, preview_samples, preview_box],
+        queue=False
+    )
 
     with gr.Row():
         max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
@@ -35,10 +42,10 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
         batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
         gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
         lr_scheduler_type = gr.Dropdown(
-            value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
+            choices=[scheduler.value for scheduler in SchedulerType], value="cosine"
         )
         max_grad_norm = gr.Textbox(value="1.0")
-        dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
+        val_size = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
 
     with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
         with gr.Row():
@@ -46,20 +53,40 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
             save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
             warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
             compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
+            padding_side = gr.Radio(choices=["left", "right"], value="left")
 
     with gr.Accordion(label="LoRA config", open=False) as lora_tab:
         with gr.Row():
             lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
-            lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01, scale=1)
+            lora_dropout = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=1)
             lora_target = gr.Textbox(scale=2)
+            resume_lora_training = gr.Checkbox(value=True, scale=1)
+
+    with gr.Accordion(label="RLHF config", open=False) as rlhf_tab:
+        with gr.Row():
+            dpo_beta = gr.Slider(value=0.1, minimum=0, maximum=1, step=0.01, scale=2)
+            reward_model = gr.Dropdown(scale=2)
+            refresh_btn = gr.Button(scale=1)
+
+    refresh_btn.click(
+        list_checkpoint,
+        [top_elems["model_name"], top_elems["finetuning_type"]],
+        [reward_model],
+        queue=False
+    )
 
     with gr.Row():
+        cmd_preview_btn = gr.Button()
         start_btn = gr.Button()
         stop_btn = gr.Button()
 
     with gr.Row():
         with gr.Column(scale=3):
-            output_dir = gr.Textbox()
+            with gr.Row():
+                output_dir = gr.Textbox()
+
+            with gr.Row():
+                process_bar = gr.Slider(visible=False, interactive=False)
 
             with gr.Box():
                 output_box = gr.Markdown()
@@ -67,49 +94,59 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
         with gr.Column(scale=1):
             loss_viewer = gr.Plot()
 
-    start_btn.click(
-        runner.run_train,
-        [
-            top_elems["lang"],
-            top_elems["model_name"],
-            top_elems["checkpoints"],
-            top_elems["finetuning_type"],
-            top_elems["quantization_bit"],
-            top_elems["template"],
-            top_elems["source_prefix"],
-            dataset_dir,
-            dataset,
-            max_source_length,
-            max_target_length,
-            learning_rate,
-            num_train_epochs,
-            max_samples,
-            batch_size,
-            gradient_accumulation_steps,
-            lr_scheduler_type,
-            max_grad_norm,
-            dev_ratio,
-            logging_steps,
-            save_steps,
-            warmup_steps,
-            compute_type,
-            lora_rank,
-            lora_dropout,
-            lora_target,
-            output_dir
-        ],
-        [output_box]
-    )
+    input_components = [
+        top_elems["lang"],
+        top_elems["model_name"],
+        top_elems["checkpoints"],
+        top_elems["finetuning_type"],
+        top_elems["quantization_bit"],
+        top_elems["template"],
+        top_elems["system_prompt"],
+        training_stage,
+        dataset_dir,
+        dataset,
+        max_source_length,
+        max_target_length,
+        learning_rate,
+        num_train_epochs,
+        max_samples,
+        batch_size,
+        gradient_accumulation_steps,
+        lr_scheduler_type,
+        max_grad_norm,
+        val_size,
+        logging_steps,
+        save_steps,
+        warmup_steps,
+        compute_type,
+        padding_side,
+        lora_rank,
+        lora_dropout,
+        lora_target,
+        resume_lora_training,
+        dpo_beta,
+        reward_model,
+        output_dir
+    ]
+
+    output_components = [
+        output_box,
+        process_bar
+    ]
+
+    cmd_preview_btn.click(runner.preview_train, input_components, output_components)
+    start_btn.click(runner.run_train, input_components, output_components)
    stop_btn.click(runner.set_abort, queue=False)
 
-    output_box.change(
+    process_bar.change(
        gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
    )
 
     return dict(
+        training_stage=training_stage,
         dataset_dir=dataset_dir,
         dataset=dataset,
-        preview_btn=preview_btn,
+        data_preview_btn=data_preview_btn,
         preview_count=preview_count,
         preview_samples=preview_samples,
         close_btn=close_btn,
@@ -122,16 +159,23 @@ def create_sft_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dict[
         gradient_accumulation_steps=gradient_accumulation_steps,
         lr_scheduler_type=lr_scheduler_type,
         max_grad_norm=max_grad_norm,
-        dev_ratio=dev_ratio,
+        val_size=val_size,
         advanced_tab=advanced_tab,
         logging_steps=logging_steps,
         save_steps=save_steps,
         warmup_steps=warmup_steps,
         compute_type=compute_type,
+        padding_side=padding_side,
         lora_tab=lora_tab,
         lora_rank=lora_rank,
         lora_dropout=lora_dropout,
         lora_target=lora_target,
+        resume_lora_training=resume_lora_training,
+        rlhf_tab=rlhf_tab,
+        dpo_beta=dpo_beta,
+        reward_model=reward_model,
+        refresh_btn=refresh_btn,
+        cmd_preview_btn=cmd_preview_btn,
         start_btn=start_btn,
         stop_btn=stop_btn,
         output_dir=output_dir,

View File

@@ -3,11 +3,13 @@ from transformers.utils.versions import require_version
from llmtuner.webui.components import ( from llmtuner.webui.components import (
create_top, create_top,
create_sft_tab, create_train_tab,
create_eval_tab, create_eval_tab,
create_infer_tab, create_infer_tab,
create_export_tab create_export_tab,
create_chat_box
) )
from llmtuner.webui.chat import WebChatModel
from llmtuner.webui.css import CSS from llmtuner.webui.css import CSS
from llmtuner.webui.manager import Manager from llmtuner.webui.manager import Manager
from llmtuner.webui.runner import Runner from llmtuner.webui.runner import Runner
@@ -22,8 +24,8 @@ def create_ui() -> gr.Blocks:
with gr.Blocks(title="Web Tuner", css=CSS) as demo: with gr.Blocks(title="Web Tuner", css=CSS) as demo:
top_elems = create_top() top_elems = create_top()
with gr.Tab("SFT"): with gr.Tab("Train"):
sft_elems = create_sft_tab(top_elems, runner) train_elems = create_train_tab(top_elems, runner)
with gr.Tab("Evaluate"): with gr.Tab("Evaluate"):
eval_elems = create_eval_tab(top_elems, runner) eval_elems = create_eval_tab(top_elems, runner)
@@ -34,7 +36,7 @@ def create_ui() -> gr.Blocks:
with gr.Tab("Export"): with gr.Tab("Export"):
export_elems = create_export_tab(top_elems) export_elems = create_export_tab(top_elems)
elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems] elem_list = [top_elems, train_elems, eval_elems, infer_elems, export_elems]
manager = Manager(elem_list) manager = Manager(elem_list)
demo.load( demo.load(
@@ -47,11 +49,29 @@ def create_ui() -> gr.Blocks:
             manager.gen_label,
             [top_elems["lang"]],
             [elem for elems in elem_list for elem in elems.values()],
+            queue=False
         )

     return demo


+def create_web_demo() -> gr.Blocks:
+    chat_model = WebChatModel(lazy_init=False)
+
+    with gr.Blocks(title="Web Demo", css=CSS) as demo:
+        lang = gr.Dropdown(choices=["en", "zh"])
+        _, _, _, chat_elems = create_chat_box(chat_model, visible=True)
+        manager = Manager([{"lang": lang}, chat_elems])
+        demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
+        lang.select(manager.gen_label, [lang], [lang] + list(chat_elems.values()), queue=False)
+
+    return demo


 if __name__ == "__main__":
     demo = create_ui()
     demo.queue()
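
One detail worth noting about the new create_web_demo (and create_ui): the Runner callbacks are Python generators, and Gradio 3.x only streams generator output when queuing is enabled, which is why both entry points call demo.queue() before demo.launch(). A toy illustration, not taken from the repo:

import time
import gradio as gr

def stream_logs():
    for i in range(3):
        time.sleep(1)
        yield "step {}".format(i)   # each yield becomes a UI update

with gr.Blocks() as demo:
    box = gr.Textbox()
    gr.Button("Run").click(stream_logs, [], [box])

demo.queue()   # generator callbacks require the queue in Gradio 3.x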

View File

@@ -77,7 +77,7 @@ LOCALES = {
"info": "构建提示词时使用的模板" "info": "构建提示词时使用的模板"
} }
}, },
"source_prefix": { "system_prompt": {
"en": { "en": {
"label": "System prompt (optional)", "label": "System prompt (optional)",
"info": "A sequence used as the default system prompt." "info": "A sequence used as the default system prompt."
@@ -87,6 +87,16 @@ LOCALES = {
"info": "默认使用的系统提示词" "info": "默认使用的系统提示词"
} }
}, },
"training_stage": {
"en": {
"label": "Stage",
"info": "The stage to perform in training."
},
"zh": {
"label": "训练阶段",
"info": "目前采用的训练方式。"
}
},
"dataset_dir": { "dataset_dir": {
"en": { "en": {
"label": "Data dir", "label": "Data dir",
@@ -105,12 +115,12 @@ LOCALES = {
"label": "数据集" "label": "数据集"
} }
}, },
"preview_btn": { "data_preview_btn": {
"en": { "en": {
"value": "Preview" "value": "Preview dataset"
}, },
"zh": { "zh": {
"value": "预览" "value": "预览数据集"
} }
}, },
"preview_count": { "preview_count": {
@@ -227,9 +237,9 @@ LOCALES = {
"info": "用于梯度裁剪的范数。" "info": "用于梯度裁剪的范数。"
} }
}, },
"dev_ratio": { "val_size": {
"en": { "en": {
"label": "Dev ratio", "label": "Val size",
"info": "Proportion of data in the dev set." "info": "Proportion of data in the dev set."
}, },
"zh": { "zh": {
@@ -277,6 +287,16 @@ LOCALES = {
"info": "是否启用 FP16 或 BF16 混合精度训练。" "info": "是否启用 FP16 或 BF16 混合精度训练。"
} }
}, },
"padding_side": {
"en": {
"label": "Padding side",
"info": "The side on which the model should have padding applied."
},
"zh": {
"label": "填充位置",
"info": "使用左填充或右填充。"
}
},
"lora_tab": { "lora_tab": {
"en": { "en": {
"label": "LoRA configurations" "label": "LoRA configurations"
@@ -315,6 +335,52 @@ LOCALES = {
"info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。" "info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。"
} }
}, },
"resume_lora_training": {
"en": {
"label": "Resume LoRA training",
"info": "Whether to resume training from the last LoRA weights or create new lora weights."
},
"zh": {
"label": "继续上次的训练",
"info": "接着上次的 LoRA 权重训练或创建一个新的 LoRA 权重。"
}
},
"rlhf_tab": {
"en": {
"label": "RLHF configurations"
},
"zh": {
"label": "RLHF 参数设置"
}
},
"dpo_beta": {
"en": {
"label": "DPO beta",
"info": "Value of the beta parameter in the DPO loss."
},
"zh": {
"label": "DPO beta 参数",
"info": "DPO 损失函数中 beta 超参数大小。"
}
},
"reward_model": {
"en": {
"label": "Reward model",
"info": "Checkpoint of the reward model for PPO training."
},
"zh": {
"label": "奖励模型",
"info": "PPO 训练中奖励模型的断点路径。"
}
},
"cmd_preview_btn": {
"en": {
"value": "Preview command"
},
"zh": {
"value": "预览命令"
}
},
"start_btn": { "start_btn": {
"en": { "en": {
"value": "Start" "value": "Start"
@@ -389,7 +455,7 @@ LOCALES = {
"value": "模型未加载,请先加载模型。" "value": "模型未加载,请先加载模型。"
} }
}, },
"prefix": { "system": {
"en": { "en": {
"placeholder": "System prompt (optional)" "placeholder": "System prompt (optional)"
}, },
@@ -513,6 +579,10 @@ ALERTS = {
"en": "Please provide export dir.", "en": "Please provide export dir.",
"zh": "请填写导出目录" "zh": "请填写导出目录"
}, },
"err_failed": {
"en": "Failed.",
"zh": "训练出错。"
},
"info_aborting": { "info_aborting": {
"en": "Aborted, wait for terminating...", "en": "Aborted, wait for terminating...",
"zh": "训练中断,正在等待线程结束……" "zh": "训练中断,正在等待线程结束……"

View File

@@ -12,12 +12,18 @@ class Manager:
     def __init__(self, elem_list: List[Dict[str, Component]]):
         self.elem_list = elem_list

-    def gen_refresh(self) -> Dict[str, Any]:
+    def gen_refresh(self, lang: str) -> Dict[str, Any]:
         refresh_dict = {
             "dataset": {"choices": list_dataset()["choices"]},
             "output_dir": {"value": get_time()}
         }
         user_config = load_config()

+        if lang:
+            refresh_dict["lang"] = {"value": lang}
+        else:
+            refresh_dict["lang"] = {"value": user_config["lang"] if user_config["lang"] else "en"}
+
         if user_config["last_model"]:
             refresh_dict["model_name"] = {"value": user_config["last_model"]}
             refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}
@@ -26,10 +32,12 @@ class Manager:
     def gen_label(self, lang: str) -> Dict[Component, Dict[str, Any]]: # cannot use TYPE_CHECKING
         update_dict = {}
-        refresh_dict = self.gen_refresh()
+        refresh_dict = self.gen_refresh(lang)

         for elems in self.elem_list:
             for name, component in elems.items():
-                update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {}))
+                update_dict[component] = gr.update(
+                    **LOCALES[name][refresh_dict["lang"]["value"]], **refresh_dict.get(name, {})
+                )

         return update_dict
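
The gen_label change above leans on a small trick: gr.update(**a, **b) merges the static locale strings with the dynamic refresh values by keyword unpacking, so one update call can change a component's label and its value together. A sketch with made-up entries mirroring the shapes used in manager.py:

import gradio as gr

LOCALES = {"output_dir": {"en": {"label": "Output dir"}, "zh": {"label": "输出目录"}}}
refresh_dict = {"output_dir": {"value": "2023-08-18-12-00-00"}}

name, lang = "output_dir", "en"
update = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {}))
# -> both the label (locale) and the current value (refresh) change in one call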

View File

@@ -1,18 +1,20 @@
+import gradio as gr
 import logging
 import os
 import threading
 import time
 import transformers
-from typing import Generator, List, Optional, Tuple
+from transformers.trainer import TRAINING_ARGS_NAME
+from typing import Any, Dict, Generator, List, Tuple

 from llmtuner.extras.callbacks import LogCallback
 from llmtuner.extras.constants import DEFAULT_MODULE
 from llmtuner.extras.logging import LoggerHandler
 from llmtuner.extras.misc import torch_gc
-from llmtuner.tuner import get_train_args, run_sft
+from llmtuner.tuner import run_exp
 from llmtuner.webui.common import get_model_path, get_save_dir
 from llmtuner.webui.locales import ALERTS
-from llmtuner.webui.utils import format_info, get_eval_results
+from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar


 class Runner:
@@ -20,49 +22,46 @@ class Runner:
     def __init__(self):
         self.aborted = False
         self.running = False
+        self.logger_handler = LoggerHandler()
+        self.logger_handler.setLevel(logging.INFO)
+        logging.root.addHandler(self.logger_handler)
+        transformers.logging.add_handler(self.logger_handler)

     def set_abort(self):
         self.aborted = True
         self.running = False

-    def initialize(
+    def _initialize(
         self, lang: str, model_name: str, dataset: List[str]
-    ) -> Tuple[str, str, LoggerHandler, LogCallback]:
+    ) -> str:
         if self.running:
-            return None, ALERTS["err_conflict"][lang], None, None
+            return ALERTS["err_conflict"][lang]

         if not model_name:
-            return None, ALERTS["err_no_model"][lang], None, None
+            return ALERTS["err_no_model"][lang]

-        model_name_or_path = get_model_path(model_name)
-        if not model_name_or_path:
-            return None, ALERTS["err_no_path"][lang], None, None
+        if not get_model_path(model_name):
+            return ALERTS["err_no_path"][lang]

         if len(dataset) == 0:
-            return None, ALERTS["err_no_dataset"][lang], None, None
+            return ALERTS["err_no_dataset"][lang]

         self.aborted = False
-        self.running = True
+        self.logger_handler.reset()
+        self.trainer_callback = LogCallback(self)
+        return ""

-        logger_handler = LoggerHandler()
-        logger_handler.setLevel(logging.INFO)
-        logging.root.addHandler(logger_handler)
-        transformers.logging.add_handler(logger_handler)
-        trainer_callback = LogCallback(self)
-        return model_name_or_path, "", logger_handler, trainer_callback

-    def finalize(
-        self, lang: str, finish_info: Optional[str] = None
+    def _finalize(
+        self, lang: str, finish_info: str
     ) -> str:
         self.running = False
         torch_gc()
         if self.aborted:
             return ALERTS["info_aborted"][lang]
         else:
-            return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
+            return finish_info

-    def run_train(
+    def _parse_train_args(
         self,
         lang: str,
         model_name: str,
@@ -70,7 +69,8 @@ class Runner:
         finetuning_type: str,
         quantization_bit: str,
         template: str,
-        source_prefix: str,
+        system_prompt: str,
+        training_stage: str,
         dataset_dir: str,
         dataset: List[str],
         max_source_length: int,
@@ -82,37 +82,39 @@ class Runner:
         gradient_accumulation_steps: int,
         lr_scheduler_type: str,
         max_grad_norm: str,
-        dev_ratio: float,
+        val_size: float,
         logging_steps: int,
         save_steps: int,
         warmup_steps: int,
         compute_type: str,
+        padding_side: str,
         lora_rank: int,
         lora_dropout: float,
         lora_target: str,
+        resume_lora_training: bool,
+        dpo_beta: float,
+        reward_model: str,
         output_dir: str
-    ) -> Generator[str, None, None]:
-        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
-        if error:
-            yield error
-            return
+    ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
         if checkpoints:
             checkpoint_dir = ",".join(
-                [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
+                [os.path.join(get_save_dir(model_name), finetuning_type, ckpt) for ckpt in checkpoints]
             )
         else:
             checkpoint_dir = None

+        output_dir = os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
+
         args = dict(
-            model_name_or_path=model_name_or_path,
+            stage="sft",
+            model_name_or_path=get_model_path(model_name),
             do_train=True,
             overwrite_cache=True,
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
             template=template,
-            source_prefix=source_prefix,
+            system_prompt=system_prompt,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),
             max_source_length=max_source_length,
@@ -127,42 +129,40 @@ class Runner:
             logging_steps=logging_steps,
             save_steps=save_steps,
             warmup_steps=warmup_steps,
-            fp16=(compute_type == "fp16"),
-            bf16=(compute_type == "bf16"),
+            padding_side=padding_side,
             lora_rank=lora_rank,
             lora_dropout=lora_dropout,
             lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
-            output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
+            resume_lora_training=resume_lora_training,
+            output_dir=output_dir
         )
+        args[compute_type] = True

-        if dev_ratio > 1e-6:
-            args["dev_ratio"] = dev_ratio
+        if training_stage == "Reward Modeling":
+            args["stage"] = "rm"
+            args["resume_lora_training"] = False
+        elif training_stage == "PPO":
+            args["stage"] = "ppo"
+            args["resume_lora_training"] = False
+            args["reward_model"] = reward_model
+            args["padding_side"] = "left"
+            val_size = 0
+        elif training_stage == "DPO":
+            args["stage"] = "dpo"
+            args["resume_lora_training"] = False
+            args["dpo_beta"] = dpo_beta
+        elif training_stage == "Pre-Training":
+            args["stage"] = "pt"
+
+        if val_size > 1e-6:
+            args["val_size"] = val_size
             args["evaluation_strategy"] = "steps"
             args["eval_steps"] = save_steps
             args["load_best_model_at_end"] = True

-        model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
-        run_args = dict(
-            model_args=model_args,
-            data_args=data_args,
-            training_args=training_args,
-            finetuning_args=finetuning_args,
-            callbacks=[trainer_callback]
-        )
-        thread = threading.Thread(target=run_sft, kwargs=run_args)
-        thread.start()
-
-        while thread.is_alive():
-            time.sleep(1)
-            if self.aborted:
-                yield ALERTS["info_aborting"][lang]
-            else:
-                yield format_info(logger_handler.log, trainer_callback.tracker)
-
-        yield self.finalize(lang)
+        return lang, model_name, dataset, output_dir, args

-    def run_eval(
+    def _parse_eval_args(
         self,
         lang: str,
         model_name: str,
@@ -170,7 +170,7 @@ class Runner:
         finetuning_type: str,
         quantization_bit: str,
         template: str,
-        source_prefix: str,
+        system_prompt: str,
         dataset_dir: str,
         dataset: List[str],
         max_source_length: int,
@@ -178,12 +178,7 @@ class Runner:
         max_samples: str,
         batch_size: int,
         predict: bool
-    ) -> Generator[str, None, None]:
-        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
-        if error:
-            yield error
-            return
+    ) -> Tuple[str, str, List[str], str, Dict[str, Any]]:
         if checkpoints:
             checkpoint_dir = ",".join(
                 [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
@@ -194,15 +189,16 @@ class Runner:
             output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")

         args = dict(
-            model_name_or_path=model_name_or_path,
+            stage="sft",
+            model_name_or_path=get_model_path(model_name),
             do_eval=True,
             overwrite_cache=True,
             predict_with_generate=True,
             checkpoint_dir=checkpoint_dir,
             finetuning_type=finetuning_type,
-            quantization_bit=int(quantization_bit) if quantization_bit else None,
+            quantization_bit=int(quantization_bit) if quantization_bit != "None" else None,
             template=template,
-            source_prefix=source_prefix,
+            system_prompt=system_prompt,
             dataset_dir=dataset_dir,
             dataset=",".join(dataset),
             max_source_length=max_source_length,
@@ -216,23 +212,72 @@ class Runner:
args.pop("do_eval", None) args.pop("do_eval", None)
args["do_predict"] = True args["do_predict"] = True
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args) return lang, model_name, dataset, output_dir, args
run_args = dict( def preview_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
model_args=model_args, lang, model_name, dataset, _, args = self._parse_train_args(*args)
data_args=data_args, error = self._initialize(lang, model_name, dataset)
training_args=training_args, if error:
finetuning_args=finetuning_args, yield error, gr.update(visible=False)
callbacks=[trainer_callback] else:
) yield gen_cmd(args), gr.update(visible=False)
thread = threading.Thread(target=run_sft, kwargs=run_args)
def preview_eval(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
lang, model_name, dataset, _, args = self._parse_eval_args(*args)
error = self._initialize(lang, model_name, dataset)
if error:
yield error, gr.update(visible=False)
else:
yield gen_cmd(args), gr.update(visible=False)
def run_train(self, *args) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
lang, model_name, dataset, output_dir, args = self._parse_train_args(*args)
error = self._initialize(lang, model_name, dataset)
if error:
yield error, gr.update(visible=False)
return
self.running = True
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
thread.start() thread.start()
while thread.is_alive(): while thread.is_alive():
time.sleep(1) time.sleep(2)
if self.aborted: if self.aborted:
yield ALERTS["info_aborting"][lang] yield ALERTS["info_aborting"][lang], gr.update(visible=False)
else: else:
yield format_info(logger_handler.log, trainer_callback.tracker) yield self.logger_handler.log, update_process_bar(self.trainer_callback)
yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json"))) if os.path.exists(os.path.join(output_dir, TRAINING_ARGS_NAME)):
finish_info = ALERTS["info_finished"][lang]
else:
finish_info = ALERTS["err_failed"][lang]
yield self._finalize(lang, finish_info), gr.update(visible=False)
def run_eval(self, *args) -> Generator[str, None, None]:
lang, model_name, dataset, output_dir, args = self._parse_eval_args(*args)
error = self._initialize(lang, model_name, dataset)
if error:
yield error, gr.update(visible=False)
return
self.running = True
run_kwargs = dict(args=args, callbacks=[self.trainer_callback])
thread = threading.Thread(target=run_exp, kwargs=run_kwargs)
thread.start()
while thread.is_alive():
time.sleep(2)
if self.aborted:
yield ALERTS["info_aborting"][lang], gr.update(visible=False)
else:
yield self.logger_handler.log, update_process_bar(self.trainer_callback)
if os.path.exists(os.path.join(output_dir, "all_results.json")):
finish_info = get_eval_results(os.path.join(output_dir, "all_results.json"))
else:
finish_info = ALERTS["err_failed"][lang]
yield self._finalize(lang, finish_info), gr.update(visible=False)
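
The run_train/run_eval rewrite keeps the same underlying pattern as before: start the job on a thread, then poll from a generator so the UI keeps receiving log text while training runs. Stripped to its essentials (plain functions instead of the Runner class, with a fake job standing in for run_exp):

import threading
import time

log_lines = []

def training_job():                        # stand-in for run_exp
    for step in range(3):
        time.sleep(1)
        log_lines.append("step {}".format(step))

def run():
    thread = threading.Thread(target=training_job)
    thread.start()
    while thread.is_alive():
        time.sleep(2)                      # same 2-second poll as the diff
        yield "\n".join(log_lines)         # streamed to the output textbox
    yield "finished"

for update in run():
    print(update)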

View File

@@ -3,22 +3,30 @@ import json
 import gradio as gr
 import matplotlib.figure
 import matplotlib.pyplot as plt
-from typing import Any, Dict, Generator, List, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple
 from datetime import datetime

 from llmtuner.extras.ploting import smooth
-from llmtuner.tuner import get_infer_args, load_model_and_tokenizer
+from llmtuner.tuner import export_model
 from llmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
 from llmtuner.webui.locales import ALERTS

+if TYPE_CHECKING:
+    from llmtuner.extras.callbacks import LogCallback


-def format_info(log: str, tracker: dict) -> str:
-    info = log
-    if "current_steps" in tracker:
-        info += "Running **{:d}/{:d}**: {} < {}\n".format(
-            tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"]
-        )
-    return info
+def update_process_bar(callback: "LogCallback") -> Dict[str, Any]:
+    if not callback.max_steps:
+        return gr.update(visible=False)
+
+    percentage = round(100 * callback.cur_steps / callback.max_steps, 0) if callback.max_steps != 0 else 100.0
+    label = "Running {:d}/{:d}: {} < {}".format(
+        callback.cur_steps,
+        callback.max_steps,
+        callback.elapsed_time,
+        callback.remaining_time
+    )
+    return gr.update(label=label, value=percentage, visible=True)


 def get_time() -> str:
@@ -54,6 +62,18 @@ def can_quantize(finetuning_type: str) -> Dict[str, Any]:
     return gr.update(interactive=True)


+def gen_cmd(args: Dict[str, Any]) -> str:
+    if args.get("do_train", None):
+        args["plot_loss"] = True
+    cmd_lines = ["CUDA_VISIBLE_DEVICES=0 python "]
+    for k, v in args.items():
+        if v is not None and v != "":
+            cmd_lines.append(" --{} {} ".format(k, str(v)))
+    cmd_text = "\\\n".join(cmd_lines)
+    cmd_text = "```bash\n{}\n```".format(cmd_text)
+    return cmd_text


 def get_eval_results(path: os.PathLike) -> str:
     with open(path, "r", encoding="utf-8") as f:
         result = json.dumps(json.load(f), indent=4)
@@ -87,8 +107,14 @@ def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotl
     return fig


-def export_model(
-    lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
+def save_model(
+    lang: str,
+    model_name: str,
+    checkpoints: List[str],
+    finetuning_type: str,
+    template: str,
+    max_shard_size: int,
+    save_dir: str
 ) -> Generator[str, None, None]:
     if not model_name:
         yield ALERTS["err_no_model"][lang]
@@ -114,12 +140,11 @@ def export_model(
     args = dict(
         model_name_or_path=model_name_or_path,
         checkpoint_dir=checkpoint_dir,
-        finetuning_type=finetuning_type
+        finetuning_type=finetuning_type,
+        template=template,
+        output_dir=save_dir
     )

     yield ALERTS["info_exporting"][lang]
-    model_args, _, finetuning_args, _ = get_infer_args(args)
-    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
-    model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB")
-    tokenizer.save_pretrained(save_dir)
+    export_model(args, max_shard_size="{}GB".format(max_shard_size))
     yield ALERTS["info_exported"][lang]

View File

@@ -1,17 +1,8 @@
-from llmtuner.tuner import get_train_args, run_pt, run_sft, run_rm, run_ppo
+from llmtuner import run_exp


 def main():
-    model_args, data_args, training_args, finetuning_args, general_args = get_train_args()
-
-    if general_args.stage == "pt":
-        run_pt(model_args, data_args, training_args, finetuning_args)
-    elif general_args.stage == "sft":
-        run_sft(model_args, data_args, training_args, finetuning_args)
-    elif general_args.stage == "rm":
-        run_rm(model_args, data_args, training_args, finetuning_args)
-    elif general_args.stage == "ppo":
-        run_ppo(model_args, data_args, training_args, finetuning_args)
+    run_exp()


 def _mp_fn(index):

View File

@@ -1,4 +1,4 @@
-from llmtuner.webui.interface import create_ui
+from llmtuner import create_ui


 def main():

View File

@@ -1,33 +1,8 @@
-# coding=utf-8
-# Implements user interface in browser for fine-tuned models.
-# Usage: python web_demo.py --model_name_or_path path_to_model --checkpoint_dir path_to_checkpoint
-
-import gradio as gr
-from transformers.utils.versions import require_version
-from llmtuner.tuner import get_infer_args
-from llmtuner.webui.chat import WebChatModel
-from llmtuner.webui.components.chatbot import create_chat_box
-from llmtuner.webui.manager import Manager
-
-require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
+from llmtuner import create_web_demo


 def main():
-    chat_model = WebChatModel(*get_infer_args())
-
-    with gr.Blocks(title="Web Demo") as demo:
-        lang = gr.Dropdown(choices=["en", "zh"], value="en")
-        _, _, _, chat_elems = create_chat_box(chat_model, visible=True)
-        manager = Manager([{"lang": lang}, chat_elems])
-        demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
-        lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
+    demo = create_web_demo()
     demo.queue()
     demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)