[misc] fix cli (#7204)
Former-commit-id: 999f57133ca163c7108d2d5ee8194eca9b2109b4
This commit is contained in:
@@ -88,18 +88,24 @@ def main():
|
||||
elif command == Command.TRAIN:
|
||||
force_torchrun = is_env_enabled("FORCE_TORCHRUN")
|
||||
if force_torchrun or (get_device_count() > 1 and not use_ray()):
|
||||
nnodes = os.getenv("NNODES", "1")
|
||||
node_rank = os.getenv("NODE_RANK", "0")
|
||||
nproc_per_node = os.getenv("NPROC_PER_NODE", str(get_device_count()))
|
||||
master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
|
||||
master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
|
||||
logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
|
||||
logger.info_rank0(f"Initializing {nproc_per_node} distributed tasks at: {master_addr}:{master_port}")
|
||||
if int(nnodes) > 1:
|
||||
print(f"Multi-node training enabled: num nodes: {nnodes}, node rank: {node_rank}")
|
||||
|
||||
process = subprocess.run(
|
||||
(
|
||||
"torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
|
||||
"--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
|
||||
)
|
||||
.format(
|
||||
nnodes=os.getenv("NNODES", "1"),
|
||||
node_rank=os.getenv("NODE_RANK", "0"),
|
||||
nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
|
||||
nnodes=nnodes,
|
||||
node_rank=node_rank,
|
||||
nproc_per_node=nproc_per_node,
|
||||
master_addr=master_addr,
|
||||
master_port=master_port,
|
||||
file_name=launcher.__file__,
|
||||
@@ -119,7 +125,7 @@ def main():
|
||||
elif command == Command.HELP:
|
||||
print(USAGE)
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown command: {command}.")
|
||||
print(f"Unknown command: {command}.\n{USAGE}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user