Source code for astra.thread_manager
from threading import Thread
from typing import Callable, Any, Dict, List
[docs]
class ThreadManager:
def __init__(self):
self.threads: List[Dict[str, Any]] = []
[docs]
def start_thread(
self,
target: Callable,
args: tuple = (),
thread_type: str = "",
device_name: str = "",
thread_id: Any = None,
daemon: bool = True,
) -> Thread:
th = Thread(target=target, args=args, daemon=daemon)
th.start()
thread_info = {
"type": thread_type,
"device_name": device_name,
"thread": th,
"id": thread_id if thread_id is not None else id(th),
}
self.threads.append(thread_info)
return th
[docs]
def join_thread(self, thread_id: Any) -> None:
for th_info in self.threads:
if th_info["id"] == thread_id:
th_info["thread"].join()
break
[docs]
def remove_dead_threads(self) -> None:
self.threads = [th for th in self.threads if th["thread"].is_alive()]
[docs]
def get_thread_ids(self) -> List[Any]:
return [th_info["id"] for th_info in self.threads]
[docs]
def get_thread(self, thread_id: Any) -> Thread | None:
for th_info in self.threads:
if th_info["id"] == thread_id:
return th_info["thread"]
return None
[docs]
def stop_thread(self, thread_id: Any) -> None:
for th_info in self.threads:
if th_info["id"] == thread_id and th_info["thread"].is_alive():
th_info["thread"].join()
self.threads.remove(th_info)
break
[docs]
def stop_all(self) -> None:
for th_info in self.threads:
if th_info["thread"].is_alive():
th_info["thread"].join()
self.threads.clear()
[docs]
def is_thread_running(self, schedule: str) -> bool:
"""
Return True if any thread of the given type is currently alive.
"""
for th in self.threads:
if th["id"] == schedule and th["thread"].is_alive():
return True
return False
[docs]
def get_thread_summary(self) -> list[dict]:
"""
Return a summary list of all threads with type, device_name, and id.
"""
return [self._thread_summary(th_info) for th_info in self.threads]
def _thread_summary(self, thread_info: dict) -> dict:
return {
"type": thread_info.get("type"),
"device_name": thread_info.get("device_name"),
"id": thread_info.get("id"),
}