[infer] set env for vllm ascend (#7745)

This commit is contained in:
hoshi-hiyouga
2025-04-17 01:08:55 +08:00
committed by GitHub
parent 2e518f255f
commit d222f63cb7
5 changed files with 28 additions and 21 deletions

View File

@@ -16,10 +16,8 @@ import os
import subprocess
import sys
from copy import deepcopy
from enum import Enum, unique
from functools import partial
from .extras import logging
USAGE = (
@@ -37,19 +35,20 @@ USAGE = (
+ "-" * 70
)
logger = logging.get_logger(__name__)
def main():
from . import launcher
from .api.app import run_api
from .chat.chat_model import run_chat
from .eval.evaluator import run_eval
from .extras import logging
from .extras.env import VERSION, print_env
from .extras.misc import find_available_port, get_device_count, is_env_enabled, use_ray
from .train.tuner import export_model, run_exp
from .webui.interface import run_web_demo, run_web_ui
logger = logging.get_logger(__name__)
WELCOME = (
"-" * 58
+ "\n"
@@ -62,7 +61,7 @@ def main():
+ "-" * 58
)
COMMANDS = {
COMMAND_MAP = {
"api": run_api,
"chat": run_chat,
"env": print_env,
@@ -75,9 +74,9 @@ def main():
"help": partial(print, USAGE),
}
command = sys.argv.pop(1) if len(sys.argv) != 1 else "help"
force_torchrun = is_env_enabled("FORCE_TORCHRUN")
if command == "train" and (force_torchrun or (get_device_count() > 1 and not use_ray())):
command = sys.argv.pop(1) if len(sys.argv) >= 1 else "help"
if command == "train" and (is_env_enabled("FORCE_TORCHRUN") or (get_device_count() > 1 and not use_ray())):
# launch distributed training
nnodes = os.getenv("NNODES", "1")
node_rank = os.getenv("NODE_RANK", "0")
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
@@ -113,11 +112,14 @@ def main():
check=True,
)
sys.exit(process.returncode)
elif command in COMMAND_MAP:
COMMAND_MAP[command]()
else:
COMMANDS[command]()
print(f"Unknown command: {command}.\n{USAGE}")
if __name__ == "__main__":
    # Script entry point: call freeze_support() first, then hand control
    # to the CLI dispatcher.
    import multiprocessing

    multiprocessing.freeze_support()
    main()