support DDP in webui
Former-commit-id: d059262ff8dc857f597d2657546ec625726a664a
This commit is contained in:
@@ -1,9 +1,16 @@
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
import sys
|
||||
from enum import Enum, unique
|
||||
|
||||
from llamafactory import launcher
|
||||
|
||||
from .api.app import run_api
|
||||
from .chat.chat_model import run_chat
|
||||
from .eval.evaluator import run_eval
|
||||
from .extras.logging import get_logger
|
||||
from .extras.misc import get_device_count
|
||||
from .train.tuner import export_model, run_exp
|
||||
from .webui.interface import run_web_demo, run_web_ui
|
||||
|
||||
@@ -37,6 +44,8 @@ WELCOME = (
|
||||
+ "-" * 58
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@unique
|
||||
class Command(str, Enum):
|
||||
@@ -62,7 +71,32 @@ def main():
|
||||
elif command == Command.EXPORT:
|
||||
export_model()
|
||||
elif command == Command.TRAIN:
|
||||
run_exp()
|
||||
if get_device_count() > 1:
|
||||
nnodes = os.environ.get("NNODES", "1")
|
||||
node_rank = os.environ.get("RANK", "0")
|
||||
nproc_per_node = os.environ.get("NPROC_PER_NODE", str(get_device_count()))
|
||||
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
|
||||
master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
|
||||
logger.info("Initializing distributed tasks at: {}:{}".format(master_addr, master_port))
|
||||
subprocess.run(
|
||||
[
|
||||
"torchrun",
|
||||
"--nnodes",
|
||||
nnodes,
|
||||
"--node_rank",
|
||||
node_rank,
|
||||
"--nproc_per_node",
|
||||
nproc_per_node,
|
||||
"--master_addr",
|
||||
master_addr,
|
||||
"--master_port",
|
||||
master_port,
|
||||
launcher.__file__,
|
||||
*sys.argv[1:],
|
||||
]
|
||||
)
|
||||
else:
|
||||
run_exp()
|
||||
elif command == Command.WEBDEMO:
|
||||
run_web_demo()
|
||||
elif command == Command.WEBUI:
|
||||
|
||||
Reference in New Issue
Block a user