mirror of
https://github.com/hiyouga/LlamaFactory.git
synced 2026-02-03 08:53:38 +00:00
[logging] Fix race condition in LoggerHandler during multi-GPU training (#10156)
Co-authored-by: yurekami <yurekami@users.noreply.github.com>
This commit is contained in:
@@ -41,12 +41,13 @@ class LoggerHandler(logging.Handler):
|
|||||||
datefmt="%Y-%m-%d %H:%M:%S",
|
datefmt="%Y-%m-%d %H:%M:%S",
|
||||||
)
|
)
|
||||||
self.setLevel(logging.INFO)
|
self.setLevel(logging.INFO)
|
||||||
|
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
self.running_log = os.path.join(output_dir, RUNNING_LOG)
|
self.running_log = os.path.join(output_dir, RUNNING_LOG)
|
||||||
if os.path.exists(self.running_log):
|
try:
|
||||||
os.remove(self.running_log)
|
os.remove(self.running_log)
|
||||||
|
except OSError:
|
||||||
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
pass
|
||||||
|
|
||||||
def _write_log(self, log_entry: str) -> None:
|
def _write_log(self, log_entry: str) -> None:
|
||||||
with open(self.running_log, "a", encoding="utf-8") as f:
|
with open(self.running_log, "a", encoding="utf-8") as f:
|
||||||
|
|||||||
Reference in New Issue
Block a user