support DDP in webui

Former-commit-id: d059262ff8dc857f597d2657546ec625726a664a
This commit is contained in:
hiyouga
2024-05-28 19:24:22 +08:00
parent 9912b43fcc
commit 9138a7a5ba
19 changed files with 78 additions and 166 deletions

View File

@@ -1,9 +1,16 @@
import os
import random
import subprocess
import sys
from enum import Enum, unique
from llamafactory import launcher
from .api.app import run_api
from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras.logging import get_logger
from .extras.misc import get_device_count
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
@@ -37,6 +44,8 @@ WELCOME = (
+ "-" * 58
)
logger = get_logger(__name__)
@unique
class Command(str, Enum):
@@ -62,7 +71,32 @@ def main():
elif command == Command.EXPORT:
export_model()
elif command == Command.TRAIN:
run_exp()
if get_device_count() > 1:
nnodes = os.environ.get("NNODES", "1")
node_rank = os.environ.get("RANK", "0")
nproc_per_node = os.environ.get("NPROC_PER_NODE", str(get_device_count()))
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
subprocess.run(
[
"torchrun",
"--nnodes",
nnodes,
"--node_rank",
node_rank,
"--nproc_per_node",
nproc_per_node,
"--master_addr",
master_addr,
"--master_port",
master_port,
launcher.__file__,
*sys.argv[1:],
]
)
else:
run_exp()
elif command == Command.WEBDEMO:
run_web_demo()
elif command == Command.WEBUI: