# --------------------------------------------------------------------------------------#
# Copyright (c) 2026 MaxwellLink #
# This file is part of MaxwellLink. Repository: https://github.com/TaoELi/MaxwellLink #
# If you use this code, always credit and cite arXiv:2512.06173. #
# See AGENTS.md and README.md for details. #
# --------------------------------------------------------------------------------------#
"""
Two-layer socket aggregation for MaxwellLink.
This module adds an opt-in transport layer on top of the existing
``SocketHub`` implementation without modifying the original hub logic.
The new design introduces two roles:
- ``AggregatedSocketHub``: an EM-side hub that keeps the public hub API
expected by MaxwellLink solvers, but aggregates multiple molecule requests
into one upstream connection per HPC node.
- ``LocalSocketHubBridge``: a node-local bridge process/thread that talks to
``AggregatedSocketHub`` upstream while reusing an ordinary downstream
:class:`~maxwelllink.sockets.sockets.SocketHub` to fan out work to multiple
existing Python/socket-only drivers.
This preserves existing ``SocketHub`` behavior while enabling a two-layer
communication topology::
EM solver -> AggregatedSocketHub ==TCP==> LocalSocketHubBridge
-> local SocketHub ==TCP/UNIX==> many molecular drivers
"""
from __future__ import annotations
import argparse
from collections.abc import Iterable
import json
import os
import selectors
import socket
import struct
import time
import threading
from dataclasses import dataclass, field
from typing import Dict, Mapping, Optional
import numpy as np
from .sockets import (
DT_FLOAT,
BYE,
STOP,
SocketHub,
_ClientState,
_SocketClosed,
_recv_bytes,
_recv_msg,
_send_bytes,
_send_msg,
)
# ---------------------------------------------------------------------------
# Aggregate wire protocol
# ---------------------------------------------------------------------------
# All aggregate frames begin with a fixed 12-byte header (one of the banners
# below, right-padded with spaces). HELLO/INIT carry a JSON payload; STEP and
# RESULT use the packed binary layouts described next to their codecs. The byte
# layout is shared by the hub and bridge processes and must stay stable.
AGGHELLO = b"AGGHELLO"
AGGINIT = b"AGGINIT"
AGGREADY = b"AGGREADY"
AGGSTEP = b"AGGSTEP"
AGGRESULT = b"AGGRESULT"
AGGREGATION_INFO_VERSION = 1
_INT32 = struct.Struct("<i")
_STRUCT_3D = struct.Struct("<3d")
_INT32_LEN = _INT32.size # 4
_FIELD_LEN = _STRUCT_3D.size # 24: one packed efield/amp vector (3 doubles)
_AGG_HEADER_LEN = 12
_AGGSTEP_HDR = AGGSTEP.ljust(_AGG_HEADER_LEN, b" ")
_AGGRESULT_HDR = AGGRESULT.ljust(_AGG_HEADER_LEN, b" ")
# AGGSTEP head: header + nreq + nuniq; each member record: molecule_id + field_idx.
_AGGSTEP_HEAD_LEN = _AGG_HEADER_LEN + _INT32_LEN + _INT32_LEN
_AGGSTEP_RECORD_LEN = _INT32_LEN + _INT32_LEN
_STEP_FIELDIDX_OFF = _INT32_LEN # field_idx follows molecule_id within a record
# AGGRESULT head: header + nresp; each record: molecule_id + amp(vec3) + extra_len.
_AGGRESULT_HEAD_LEN = _AGG_HEADER_LEN + _INT32_LEN
_AGGRESULT_RECORD_LEN = _INT32_LEN + _FIELD_LEN + _INT32_LEN
_RESULT_AMP_OFF = _INT32_LEN # amp follows molecule_id
_RESULT_EXTRALEN_OFF = _INT32_LEN + _FIELD_LEN # extra_len follows amp
# ---------------------------------------------------------------------------
# Low-level helpers
# ---------------------------------------------------------------------------
def _json_dumps_bytes(payload: Mapping) -> bytes:
"""Encode a mapping into compact UTF-8 JSON bytes."""
return json.dumps(
payload,
ensure_ascii=False,
separators=(",", ":"),
sort_keys=True,
).encode("utf-8")
def _json_loads_bytes(payload: bytes) -> dict:
"""Decode a UTF-8 JSON payload, defaulting empty content to ``{}``."""
return json.loads(payload.decode("utf-8") or "{}")
def _recv_msg_with_timeout(sock: socket.socket, timeout: float) -> bytes:
"""
Receive one 12-byte MaxwellLink header using a temporary timeout.
This is used while discovering fresh bridge clients so the hub can poll
for their HELLO payload without blocking the whole EM-side wait loop.
"""
old_timeout = sock.gettimeout()
try:
sock.settimeout(timeout)
return _recv_msg(sock)
finally:
sock.settimeout(old_timeout)
# Selector (un)register raises these when a socket is unknown or already closed;
# they are always safe to ignore on a best-effort detach.
_SELECTOR_ERRORS = (KeyError, ValueError, OSError)
def _close_socket(sock: Optional[socket.socket]) -> None:
"""Close a socket, ignoring the error if it is already gone."""
if sock is None:
return
try:
sock.close()
except OSError:
pass
def _recv_exact_into(sock: socket.socket, buf, nbytes: int) -> None:
"""Read exactly ``nbytes`` into the start of ``buf``."""
mv = memoryview(buf)
got = 0
while got < nbytes:
nrecv = sock.recv_into(mv[got:nbytes], nbytes - got)
if nrecv == 0:
raise _SocketClosed("Peer closed")
got += nrecv
def _expect_header(buf, expected: bytes) -> None:
"""Validate the 12-byte banner at the start of ``buf``."""
got = bytes(memoryview(buf)[:_AGG_HEADER_LEN]).rstrip()
if got != expected:
raise RuntimeError(f"Expected {expected!r}, got {got!r}")
def _connect_tcp_with_retry(address: str, port: int, timeout: float) -> socket.socket:
"""Connect to a TCP server with bounded retries."""
deadline = time.monotonic() + float(timeout)
delay = 0.05
last_error = None
while True:
remaining = deadline - time.monotonic()
if remaining <= 0:
break
sock = socket.socket(socket.AF_INET)
try:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
except (OSError, AttributeError):
pass
sock.settimeout(min(10.0, max(0.25, remaining)))
try:
sock.connect((address, port))
sock.settimeout(timeout)
return sock
except (ConnectionRefusedError, TimeoutError, socket.timeout, OSError) as exc:
last_error = exc
_close_socket(sock)
remaining = deadline - time.monotonic()
if remaining <= 0:
break
time.sleep(min(delay, remaining))
delay = min(delay * 1.5, 1.0)
raise TimeoutError(
f"Timed out connecting to aggregated hub at {(address, port)!r}"
) from last_error
# ---------------------------------------------------------------------------
# HELLO / INIT framing (JSON payloads)
# ---------------------------------------------------------------------------
def _send_aggregate_hello(sock: socket.socket, *, group_id: str) -> None:
"""Send the bridge HELLO banner used by the aggregate protocol."""
_send_msg(sock, AGGHELLO)
_send_bytes(sock, _json_dumps_bytes({"group_id": str(group_id), "version": 1}))
def _send_aggregate_init(
sock: socket.socket,
*,
group_id: str,
init_payloads: Mapping[int, dict],
) -> None:
"""Send group membership plus per-molecule INIT payloads to a bridge."""
payload = {
"group_id": str(group_id),
"molecule_ids": [int(mid) for mid in init_payloads.keys()],
"init_payloads": {
str(int(mid)): {**dict(init_payloads[mid]), "molecule_id": int(mid)}
for mid in init_payloads.keys()
},
}
_send_msg(sock, AGGINIT)
_send_bytes(sock, _json_dumps_bytes(payload))
# ---------------------------------------------------------------------------
# STEP / RESULT codecs (packed binary frames)
# ---------------------------------------------------------------------------
class _FrameCodec:
"""
Base class holding reusable scratch buffers for the packed frame codecs.
A codec instance is used in only one direction (the hub sends and the
bridge receives, or vice versa), so the named scratch buffers requested via
:meth:`_scratch` are reused across calls to avoid per-step allocations.
"""
def __init__(self) -> None:
self._scratch_buffers: Dict[str, bytearray] = {}
def _scratch(self, name: str, size: int) -> bytearray:
"""Return a reusable named buffer holding at least ``size`` bytes."""
buf = self._scratch_buffers.get(name)
if buf is None or len(buf) < size:
buf = bytearray(size)
self._scratch_buffers[name] = buf
return buf
class _StepCodec(_FrameCodec):
"""
Encoder/decoder for the AGGSTEP fan-out frame.
The hub encodes (``send``) and the bridge decodes (``recv``).
Frame layout::
[ header(12) | nreq(i32) | nuniq(i32) ]
[ nuniq * field(3 doubles) ]
[ nreq * (molecule_id(i32), field_idx(i32)) ]
Repeated efields are de-duplicated so molecules sharing a field reference
the same packed vector by index.
"""
def send(
self, sock: socket.socket, requests: Mapping[int, Mapping[str, np.ndarray]]
) -> None:
"""Pack and send one grouped fan-out step as a single frame."""
unique_fields: list[tuple[float, float, float]] = []
field_to_idx: dict[tuple[float, float, float], int] = {}
members: list[tuple[int, int]] = []
for mid, payload in requests.items():
field = np.asarray(payload["efield_au"], dtype=DT_FLOAT).reshape(3)
key = (float(field[0]), float(field[1]), float(field[2]))
field_idx = field_to_idx.get(key)
if field_idx is None:
field_idx = len(unique_fields)
unique_fields.append(key)
field_to_idx[key] = field_idx
members.append((int(mid), field_idx))
frame_len = (
_AGGSTEP_HEAD_LEN
+ _FIELD_LEN * len(unique_fields)
+ _AGGSTEP_RECORD_LEN * len(members)
)
buf = self._scratch("send", frame_len)
buf[:_AGG_HEADER_LEN] = _AGGSTEP_HDR
_INT32.pack_into(buf, _AGG_HEADER_LEN, len(members))
_INT32.pack_into(buf, _AGG_HEADER_LEN + _INT32_LEN, len(unique_fields))
offset = _AGGSTEP_HEAD_LEN
for fx, fy, fz in unique_fields:
_STRUCT_3D.pack_into(buf, offset, fx, fy, fz)
offset += _FIELD_LEN
for mid, field_idx in members:
_INT32.pack_into(buf, offset, mid)
_INT32.pack_into(buf, offset + _STEP_FIELDIDX_OFF, field_idx)
offset += _AGGSTEP_RECORD_LEN
sock.sendall(memoryview(buf)[:frame_len])
def recv(
self, sock: socket.socket, *, header_already_read: bool = False
) -> Dict[int, np.ndarray]:
"""
Receive one grouped fan-out step.
``header_already_read`` is set when the caller already consumed the
12-byte banner (e.g. the bridge's main dispatch loop) so only the rest
of the header is read here.
"""
head = self._scratch("head", _AGGSTEP_HEAD_LEN)
if header_already_read:
head[:_AGG_HEADER_LEN] = _AGGSTEP_HDR
_recv_exact_into(
sock,
memoryview(head)[_AGG_HEADER_LEN:],
_AGGSTEP_HEAD_LEN - _AGG_HEADER_LEN,
)
else:
_recv_exact_into(sock, head, _AGGSTEP_HEAD_LEN)
_expect_header(head, AGGSTEP)
nreq = _INT32.unpack_from(head, _AGG_HEADER_LEN)[0]
nuniq = _INT32.unpack_from(head, _AGG_HEADER_LEN + _INT32_LEN)[0]
body_len = _FIELD_LEN * nuniq + _AGGSTEP_RECORD_LEN * nreq
body = self._scratch("body", body_len)
if body_len:
_recv_exact_into(sock, body, body_len)
offset = 0
fields: list[np.ndarray] = []
for _ in range(nuniq):
fx, fy, fz = _STRUCT_3D.unpack_from(body, offset)
fields.append(np.array((fx, fy, fz), dtype=float))
offset += _FIELD_LEN
requests: Dict[int, np.ndarray] = {}
for _ in range(nreq):
mid = int(_INT32.unpack_from(body, offset)[0])
field_idx = _INT32.unpack_from(body, offset + _STEP_FIELDIDX_OFF)[0]
offset += _AGGSTEP_RECORD_LEN
requests[mid] = fields[field_idx]
return requests
class _ResultCodec(_FrameCodec):
"""
Encoder/decoder for the AGGRESULT reply frame.
The bridge encodes (``send``) and the hub decodes (``recv``).
Frame layout::
[ header(12) | nresp(i32) ]
[ nresp * (molecule_id(i32), amp(3 doubles), extra_len(i32)) ]
[ concatenated extra payload bytes ]
"""
def send(
self, sock: socket.socket, responses: Mapping[int, Mapping[str, object]]
) -> None:
"""Pack and send grouped molecule responses as a single frame."""
packed: list[tuple[int, tuple[float, float, float], bytes]] = []
total_extra = 0
for mid, payload in responses.items():
amp = np.asarray(payload["amp"], dtype=DT_FLOAT).reshape(3)
extra = payload.get("extra", b"")
if isinstance(extra, str):
extra = extra.encode("utf-8")
extra = bytes(extra)
packed.append(
(int(mid), (float(amp[0]), float(amp[1]), float(amp[2])), extra)
)
total_extra += len(extra)
fixed_len = _AGGRESULT_HEAD_LEN + _AGGRESULT_RECORD_LEN * len(packed)
frame_len = fixed_len + total_extra
buf = self._scratch("send", frame_len)
buf[:_AGG_HEADER_LEN] = _AGGRESULT_HDR
_INT32.pack_into(buf, _AGG_HEADER_LEN, len(packed))
offset = _AGGRESULT_HEAD_LEN
extra_offset = fixed_len
for mid, amp, extra in packed:
_INT32.pack_into(buf, offset, mid)
_STRUCT_3D.pack_into(buf, offset + _RESULT_AMP_OFF, amp[0], amp[1], amp[2])
_INT32.pack_into(buf, offset + _RESULT_EXTRALEN_OFF, len(extra))
offset += _AGGRESULT_RECORD_LEN
if extra:
buf[extra_offset : extra_offset + len(extra)] = extra
extra_offset += len(extra)
sock.sendall(memoryview(buf)[:frame_len])
def recv(self, sock: socket.socket) -> Dict[int, dict]:
"""Receive grouped molecule responses from a bridge."""
head = self._scratch("head", _AGGRESULT_HEAD_LEN)
_recv_exact_into(sock, head, _AGGRESULT_HEAD_LEN)
_expect_header(head, AGGRESULT)
nresp = _INT32.unpack_from(head, _AGG_HEADER_LEN)[0]
fixed_len = _AGGRESULT_RECORD_LEN * nresp
fixed = self._scratch("fixed", fixed_len)
if fixed_len:
_recv_exact_into(sock, fixed, fixed_len)
offset = 0
meta: list[tuple[int, tuple[float, float, float], int]] = []
total_extra = 0
for _ in range(nresp):
mid = int(_INT32.unpack_from(fixed, offset)[0])
amp = _STRUCT_3D.unpack_from(fixed, offset + _RESULT_AMP_OFF)
extra_len = _INT32.unpack_from(fixed, offset + _RESULT_EXTRALEN_OFF)[0]
meta.append((mid, amp, extra_len))
total_extra += extra_len
offset += _AGGRESULT_RECORD_LEN
extras = self._scratch("extras", total_extra)
if total_extra:
_recv_exact_into(sock, extras, total_extra)
responses: Dict[int, dict] = {}
extra_offset = 0
for mid, amp, extra_len in meta:
extra = (
bytes(memoryview(extras)[extra_offset : extra_offset + extra_len])
if extra_len
else b""
)
responses[mid] = {"amp": np.array(amp, dtype=float), "extra": extra}
extra_offset += extra_len
return responses
# ---------------------------------------------------------------------------
# Hub-side state and manifest specs
# ---------------------------------------------------------------------------
@dataclass
class _AggregateGroupState:
"""Per-bridge group state tracked by :class:`AggregatedSocketHub`."""
group_id: str
molecule_ids: list[int] = field(default_factory=list)
init_payloads: Dict[int, dict] = field(default_factory=dict)
bridge: Optional[_ClientState] = None
step_codec: _StepCodec = field(default_factory=_StepCodec)
result_codec: _ResultCodec = field(default_factory=_ResultCodec)
[docs]
@dataclass(frozen=True)
class RemoteBridgeSpec:
"""
One remote aggregate bridge entry produced by ``init_remote_bridges``.
Parameters
----------
idx : int
Zero-based bridge index used by :func:`run_bridge_node`.
group_id : str
Aggregate group identifier transmitted upstream.
unixsocket : str
Downstream UNIX-socket address local drivers should connect to.
n_molecules : int
Number of molecules assigned to this bridge.
"""
idx: int
group_id: str
unixsocket: str
n_molecules: int
[docs]
def to_dict(self) -> dict:
"""Return a JSON-serializable bridge specification mapping."""
return {
"idx": int(self.idx),
"group_id": str(self.group_id),
"unixsocket": str(self.unixsocket),
"n_molecules": int(self.n_molecules),
}
[docs]
@classmethod
def from_dict(cls, payload: Mapping) -> "RemoteBridgeSpec":
"""Build one bridge specification from JSON-decoded manifest data."""
return cls(
idx=int(payload["idx"]),
group_id=str(payload["group_id"]),
unixsocket=str(payload["unixsocket"]),
n_molecules=int(payload["n_molecules"]),
)
def _as_molecule_list(molecules) -> list:
"""Normalize one molecule or an iterable of molecules into a list."""
if hasattr(molecules, "init_payload"):
return [molecules]
if isinstance(molecules, Iterable) and not isinstance(
molecules, (str, bytes, bytearray)
):
return list(molecules)
raise TypeError(
"Expected one molecule or an iterable of molecules with 'init_payload'."
)
def _assign_molecule_to_group(
molecule,
*,
expected_hub: "AggregatedSocketHub",
group_id: str,
) -> None:
"""Assign one molecule to the given aggregate group in-place."""
if not hasattr(molecule, "init_payload"):
raise TypeError(
"Expected a molecule-like object carrying an 'init_payload' attribute."
)
molecule_hub = getattr(molecule, "hub", expected_hub)
if molecule_hub is not expected_hub:
raise ValueError(
"All molecules assigned to remote aggregate bridges must use the same hub."
)
payload = molecule.init_payload
if payload is None:
payload = {}
molecule.init_payload = payload
elif not isinstance(payload, dict):
payload = dict(payload)
molecule.init_payload = payload
previous = payload.get("aggregate_group")
if previous is not None and str(previous).strip() not in ("", group_id):
raise ValueError(
f"Molecule is already assigned to aggregate_group {previous!r}, "
f"cannot move it to {group_id!r}."
)
payload["aggregate_group"] = group_id
def _load_aggregation_info(info="aggregation.json") -> dict:
"""Load one JSON aggregation manifest from disk."""
with open(os.fspath(info), "r", encoding="utf-8") as f:
payload = json.load(f)
if not isinstance(payload, dict):
raise ValueError("Aggregation info file must contain a JSON object.")
return payload
def _coerce_remote_bridge_specs(payload: Mapping) -> list[RemoteBridgeSpec]:
"""Decode and validate the ``bridges`` section of an aggregation manifest."""
raw_bridges = payload.get("bridges", [])
if not isinstance(raw_bridges, list):
raise ValueError("Aggregation info must contain a 'bridges' list.")
specs = [RemoteBridgeSpec.from_dict(item) for item in raw_bridges]
seen = set()
for spec in specs:
if spec.idx in seen:
raise ValueError(
f"Aggregation info contains duplicate bridge idx {spec.idx}."
)
seen.add(spec.idx)
return specs
# ---------------------------------------------------------------------------
# Bridge-node entry points
# ---------------------------------------------------------------------------
[docs]
def run_bridge_node(info="aggregation.json", *, idx: int = 0) -> None:
"""
Run one bridge node from a manifest written by ``init_remote_bridges``.
Parameters
----------
info : str or path-like, default: ``"aggregation.json"``
JSON manifest written by :meth:`AggregatedSocketHub.init_remote_bridges`.
idx : int, default: 0
Zero-based bridge index identifying which bridge entry in ``info`` this
node should start.
"""
payload = _load_aggregation_info(info)
specs = _coerce_remote_bridge_specs(payload)
bridge_idx = int(idx)
try:
spec = next(spec for spec in specs if spec.idx == bridge_idx)
except StopIteration as exc:
available = ", ".join(str(spec.idx) for spec in specs) or "<none>"
raise IndexError(
f"Bridge idx {bridge_idx} not found in aggregation info. "
f"Available bridge indices: {available}."
) from exc
bridge = LocalSocketHubBridge(
group_id=spec.group_id,
upstream_host=str(payload["hub_host"]),
upstream_port=int(payload["hub_port"]),
timeout=float(payload.get("timeout", 60.0)),
latency=float(payload.get("latency", 0.01)),
local_unixsocket=spec.unixsocket,
)
thread = bridge.start()
try:
while thread.is_alive():
thread.join(timeout=1.0)
except KeyboardInterrupt:
pass
finally:
try:
bridge.stop(wait=max(2.0, 10.0 * bridge.latency))
except Exception:
pass
[docs]
def mxl_bridge_main(argv: list[str] | None = None) -> int:
"""
CLI entry point for running one aggregate bridge from a manifest.
Examples
--------
``mxl_bridge --info aggregation.json --idx 0``
"""
parser = argparse.ArgumentParser(
description="Run one MaxwellLink aggregate bridge node."
)
parser.add_argument(
"--info",
type=str,
default="aggregation.json",
help="Path to the aggregation manifest written by init_remote_bridges().",
)
parser.add_argument(
"--idx",
type=int,
default=0,
help="Zero-based bridge index within the aggregation manifest.",
)
args = parser.parse_args(argv)
run_bridge_node(info=args.info, idx=args.idx)
return 0
# ---------------------------------------------------------------------------
# Hub-owned convenience bridge handle
# ---------------------------------------------------------------------------
[docs]
class AggregatedBridge:
"""
Convenience handle for one hub-owned local bridge.
Instances of this class are returned by
:meth:`AggregatedSocketHub.add_bridge`. They provide a light wrapper around
:class:`LocalSocketHubBridge` so existing input scripts only need to:
1. create bridge handles from the hub,
2. attach molecules to a handle via :meth:`append`, and
3. launch downstream drivers against ``address``.
"""
[docs]
def __init__(
self,
*,
hub: "AggregatedSocketHub",
group_id: str,
bridge: "LocalSocketHubBridge",
):
self.hub = hub
self.group_id = str(group_id)
self._bridge = bridge
@property
def address(self) -> str:
"""Address string downstream UNIX-socket drivers should use."""
if self._bridge.local_unixsocket is None:
raise RuntimeError("This convenience bridge does not use a UNIX socket.")
return self._bridge.local_unixsocket
@property
def unixsocket(self) -> Optional[str]:
"""Configured UNIX-socket driver address, if any."""
return self._bridge.local_unixsocket
@property
def unixsocket_path(self) -> Optional[str]:
"""Resolved filesystem path for the local UNIX socket."""
return self._bridge.local_hub.unixsocket_path
@property
def local_endpoint(self) -> dict:
"""Return the downstream endpoint mapping for driver launch code."""
return dict(self._bridge.local_endpoint)
[docs]
def append(self, molecules) -> None:
"""
Attach one molecule or an iterable of molecules to this bridge group.
The helper only mutates ``molecule.init_payload["aggregate_group"]`` and
therefore works with existing ``mxl.Molecule`` / ``SocketMolecule``
objects without changing solver-side logic.
"""
for molecule in _as_molecule_list(molecules):
_assign_molecule_to_group(
molecule,
expected_hub=self.hub,
group_id=self.group_id,
)
[docs]
def start(self) -> threading.Thread:
"""Start the underlying local bridge thread."""
return self._bridge.start()
[docs]
def stop(self, wait: float = 2.0) -> None:
"""Stop the underlying local bridge."""
self._bridge.stop(wait=wait)
# ---------------------------------------------------------------------------
# EM-side aggregated hub
# ---------------------------------------------------------------------------
[docs]
class AggregatedSocketHub(SocketHub):
"""
EM-side hub that aggregates multiple molecule requests into one bridge link.
This class keeps the same public methods used by MaxwellLink solvers
(``register_molecule_return_id``, ``wait_until_bound``, ``all_bound``,
``step_barrier``) while mapping many molecule IDs onto a smaller number of
bridge connections.
Molecules are assigned to a bridge group through
``init_payload["aggregate_group"]``. All molecules sharing the same group
are sent together to one :class:`LocalSocketHubBridge`.
"""
[docs]
def __init__(
self,
host: Optional[str] = None,
port: Optional[int] = 31415,
timeout: float = 60000.0,
latency: float = 0.01,
):
super().__init__(
host=host,
port=port,
unixsocket=None,
timeout=timeout,
latency=latency,
)
self._groups: Dict[str, _AggregateGroupState] = {}
self._molecule_to_group: Dict[int, str] = {}
self._bridge_connect_host = (
"127.0.0.1" if host in (None, "", "0.0.0.0") else str(host)
)
self._bridge_connect_port = int(port or 31415)
self._owned_bridges: list[AggregatedBridge] = []
self._bridge_counter = 0
self.remote_bridges: list[RemoteBridgeSpec] = []
self.remote_bridge_info: Optional[dict] = None
self._bridge_selector = selectors.DefaultSelector()
# -- Bridge setup ------------------------------------------------------
[docs]
def add_bridge(self, local_unixsocket: str) -> AggregatedBridge:
"""
Create, start, and return one hub-owned local UNIX-socket bridge.
This is the convenience entry point intended for minimal edits when
migrating an existing single-layer ``SocketHub`` script to the new
two-layer transport.
"""
unix_name = str(local_unixsocket).strip()
if not unix_name:
raise ValueError("local_unixsocket must be a non-empty string.")
for handle in self._owned_bridges:
if handle.unixsocket == unix_name:
raise ValueError(
f"A bridge for local unix address {unix_name!r} already exists."
)
group_id = f"node-{self._bridge_counter}"
self._bridge_counter += 1
bridge = LocalSocketHubBridge(
group_id=group_id,
upstream_host=self._bridge_connect_host,
upstream_port=self._bridge_connect_port,
timeout=self.timeout,
latency=self.latency,
local_unixsocket=unix_name,
)
bridge.start()
handle = AggregatedBridge(hub=self, group_id=group_id, bridge=bridge)
self._owned_bridges.append(handle)
self._log(
f"STARTED: aggregate group {group_id!r} -> unix address {handle.address!r}"
)
if handle.unixsocket_path and handle.unixsocket_path != handle.address:
self._log(f"UNIX PATH: {handle.unixsocket_path}")
return handle
[docs]
def init_remote_bridges(
self,
molecules,
*,
molecules_per_bridge: int,
unix_prefix: str = "bridge_",
save_file: str = "aggregation.json",
) -> list[RemoteBridgeSpec]:
"""
Partition molecules across remote bridge groups and save a manifest.
This helper does not start any bridge threads locally. Instead it
assigns ``molecule.init_payload["aggregate_group"]`` for each molecule
and writes one JSON manifest that bridge-node scripts can consume via
:func:`run_bridge_node`.
Parameters
----------
molecules : molecule or iterable of molecules
Molecules to distribute across remote bridges.
molecules_per_bridge : int
Maximum number of molecules assigned to one bridge.
unix_prefix : str, default: ``"bridge_"``
Prefix used to generate downstream UNIX socket names
``f"{unix_prefix}{idx}"``.
save_file : str, default: ``"aggregation.json"``
Path where the bridge manifest should be written.
Returns
-------
list[RemoteBridgeSpec]
The generated bridge specifications in order.
"""
items = _as_molecule_list(molecules)
if not items:
raise ValueError("init_remote_bridges(...) requires at least one molecule.")
molecules_per_group = int(molecules_per_bridge)
if molecules_per_group <= 0:
raise ValueError("molecules_per_bridge must be a positive integer.")
prefix = str(unix_prefix)
specs: list[RemoteBridgeSpec] = []
for start in range(0, len(items), molecules_per_group):
idx = len(specs)
group_items = items[start : start + molecules_per_group]
unixsocket = f"{prefix}{idx}"
group_id = unixsocket
for molecule in group_items:
_assign_molecule_to_group(
molecule,
expected_hub=self,
group_id=group_id,
)
specs.append(
RemoteBridgeSpec(
idx=idx,
group_id=group_id,
unixsocket=unixsocket,
n_molecules=len(group_items),
)
)
payload = {
"version": AGGREGATION_INFO_VERSION,
"hub_host": self._bridge_connect_host,
"hub_port": self._bridge_connect_port,
"timeout": float(self.timeout),
"latency": float(self.latency),
"unix_prefix": prefix,
"molecules_per_bridge": molecules_per_group,
"bridges": [spec.to_dict() for spec in specs],
}
with open(os.fspath(save_file), "w", encoding="utf-8") as f:
json.dump(payload, f, ensure_ascii=False, indent=2, sort_keys=True)
self.remote_bridges = list(specs)
self.remote_bridge_info = payload
self._log(
f"Prepared {len(specs)} remote aggregate bridge(s); "
f"manifest saved to {save_file!r}."
)
for spec in specs:
self._log(
f"REMOTE BRIDGE {spec.idx}: unix={spec.unixsocket!r} "
f"group={spec.group_id!r} molecules={spec.n_molecules}"
)
return specs
# -- Group bookkeeping -------------------------------------------------
def _deadline(self, timeout: Optional[float]) -> float:
"""Return an absolute monotonic-wall deadline for ``timeout`` seconds.
Falls back to the hub-wide ``self.timeout`` when ``timeout`` is ``None``.
"""
span = float(timeout) if timeout is not None else float(self.timeout)
return time.time() + span
def _extract_group_id(self, init_payload: Mapping, molecule_id: int) -> str:
"""Return the aggregate group for one molecule."""
group_id = init_payload.get("aggregate_group")
if group_id is None:
return f"molecule-{int(molecule_id)}"
group_id = str(group_id).strip()
if not group_id:
raise ValueError(
f"aggregate_group for molecule {int(molecule_id)} must be non-empty"
)
return group_id
def _prepare_groups_locked(self, init_payloads: Mapping[int, dict]) -> None:
"""Build or update aggregate group metadata from solver INIT payloads."""
for mid, raw_payload in init_payloads.items():
molid = int(mid)
payload = {**dict(raw_payload), "molecule_id": molid}
previous = self._molecule_to_group.get(molid)
if "aggregate_group" not in payload and previous is not None:
group_id = previous
else:
group_id = self._extract_group_id(payload, molid)
if previous is not None and previous != group_id:
raise ValueError(
f"Molecule {molid} was already assigned to aggregate_group "
f"{previous!r}, cannot reassign it to {group_id!r}."
)
self._molecule_to_group[molid] = group_id
group = self._groups.setdefault(group_id, _AggregateGroupState(group_id))
group.init_payloads[molid] = payload
if molid not in group.molecule_ids:
group.molecule_ids.append(molid)
def _group_and_bridge(
self, group_id: str
) -> tuple[Optional[_AggregateGroupState], Optional[_ClientState]]:
"""Return ``(group, bridge)`` for ``group_id`` under the hub lock."""
with self._lock:
group = self._groups.get(group_id)
st = None if group is None else group.bridge
return group, st
# -- Bridge socket registration ---------------------------------------
def _register_bridge_sock(self, sock: socket.socket, group_id: str) -> None:
"""Register a bridge socket for readable events with its group id."""
self._unregister_bridge_sock(sock)
try:
self._bridge_selector.register(sock, selectors.EVENT_READ, data=group_id)
except _SELECTOR_ERRORS:
pass
def _unregister_bridge_sock(self, sock: socket.socket) -> None:
"""Unregister a bridge socket from the aggregate selector."""
try:
self._bridge_selector.unregister(sock)
except _SELECTOR_ERRORS:
pass
def _detach_sock_locked(self, st: _ClientState) -> None:
"""Unregister a client from both selectors and mark it dead.
Caller must hold ``self._lock`` and is responsible for closing the
socket afterwards (typically outside the lock).
"""
self._unregister_bridge_sock(st.sock)
self._unregister_sock(st.sock)
st.alive = False
st.initialized = False
def _bind_group_locked(self, group_id: str, st_key, st: _ClientState) -> None:
"""Attach one accepted bridge socket to a configured group."""
group = self._groups[group_id]
group.bridge = st
st.molecule_id = group.molecule_ids[0] if group.molecule_ids else -1
st.initialized = False
st.extras["aggregate_group"] = group_id
self.clients[group_id] = st
if st_key != group_id:
self.clients.pop(st_key, None)
for mid in group.molecule_ids:
self.bound[mid] = st
self._register_bridge_sock(st.sock, group_id)
self._log(f"CONNECTED: aggregate group {group_id!r} <- {st.address}")
def _drop_client_locked(self, st_key, st: _ClientState, reason: str) -> None:
"""Remove a temporary or duplicate bridge client."""
self.clients.pop(st_key, None)
self._detach_sock_locked(st)
_close_socket(st.sock)
self._log(f"DROPPED ({reason}): {st.address}")
def _mark_group_dead(self, group_id: str, reason: str) -> None:
"""Mark an aggregate group as disconnected and clear all molecule bindings."""
with self._lock:
group = self._groups.get(group_id)
if group is None or group.bridge is None:
return
st = group.bridge
group.bridge = None
self.clients.pop(group_id, None)
self._detach_sock_locked(st)
for mid in group.molecule_ids:
if self.bound.get(mid) is st:
self.bound[mid] = None
self._log(f"DISCONNECTED ({reason}): aggregate group {group_id!r}")
_close_socket(st.sock)
self._pause()
# -- Binding handshake -------------------------------------------------
def _snapshot_unbound_clients(self, *, identified: bool) -> list:
"""Snapshot still-unbound bridge clients under the hub lock.
``identified`` selects clients that have already announced their
``aggregate_group`` via HELLO (``True``) versus those still awaiting it
(``False``). Returns a list of ``(client_key, client_state)`` pairs.
"""
with self._lock:
return [
(st_key, st)
for st_key, st in list(self.clients.items())
if st is not None
and st.alive
and st.molecule_id < 0
and ("aggregate_group" in st.extras) == identified
]
def _try_identify_fresh_clients(self) -> None:
"""
Poll newly accepted sockets for bridge HELLO messages.
A bridge sends HELLO immediately after connecting. We keep the read
timeout short here so one slow client cannot stall the entire hub.
"""
for st_key, st in self._snapshot_unbound_clients(identified=False):
try:
msg = _recv_msg_with_timeout(st.sock, max(self.latency, 0.05))
except socket.timeout:
continue
except (RuntimeError, _SocketClosed, OSError):
with self._lock:
self._drop_client_locked(st_key, st, reason="hello")
continue
if msg != AGGHELLO:
with self._lock:
self._drop_client_locked(st_key, st, reason="hello-header")
continue
try:
hello = _json_loads_bytes(_recv_bytes(st.sock))
except (RuntimeError, _SocketClosed, OSError):
with self._lock:
self._drop_client_locked(st_key, st, reason="hello-payload")
continue
group_id = str(hello.get("group_id", "")).strip()
if not group_id:
with self._lock:
self._drop_client_locked(st_key, st, reason="hello-group")
continue
with self._lock:
st.extras["aggregate_group"] = group_id
def _progress_group_binds(self) -> None:
"""Bind identified bridge clients to configured groups whenever possible."""
with self._lock:
for st_key, st in self._snapshot_unbound_clients(identified=True):
group_id = st.extras["aggregate_group"]
group = self._groups.get(group_id)
if group is None:
continue
if group.bridge is None:
self._bind_group_locked(group_id, st_key, st)
elif group.bridge is not st:
self._drop_client_locked(st_key, st, reason="duplicate-group")
def _initialize_group(self, group_id: str) -> bool:
"""Send AGGINIT to a bound bridge and wait for AGGREADY."""
with self._lock:
group = self._groups[group_id]
st = group.bridge
init_payloads = dict(group.init_payloads)
if st is None or not st.alive:
return False
try:
_send_aggregate_init(
st.sock,
group_id=group_id,
init_payloads=init_payloads,
)
msg = _recv_msg_with_timeout(st.sock, self.timeout)
if msg != AGGREADY:
raise RuntimeError(f"Expected {AGGREADY!r}, got {msg!r}")
except (socket.timeout, RuntimeError, _SocketClosed, OSError):
self._mark_group_dead(group_id, reason="init")
return False
with self._lock:
if group.bridge is st and st.alive:
st.initialized = True
return True
return False
[docs]
def wait_until_bound(self, init_payloads: dict, require_init=True, timeout=None):
"""
Wait until all requested molecules are served by initialized bridges.
Molecules are grouped through ``init_payload["aggregate_group"]`` and each
group must be backed by exactly one connected bridge.
"""
wanted = {int(mid) for mid in init_payloads.keys()}
deadline = self._deadline(timeout)
payloads = {
int(mid): {**dict(init_payloads[mid]), "molecule_id": int(mid)}
for mid in init_payloads.keys()
}
with self._lock:
self._prepare_groups_locked(payloads)
while True:
if self.all_bound(wanted, require_init=require_init):
self._resume()
return True
self._try_identify_fresh_clients()
self._progress_group_binds()
with self._lock:
groups_needing_init = [
group_id
for group_id, group in self._groups.items()
if any(mid in wanted for mid in group.molecule_ids)
and group.bridge is not None
and group.bridge.alive
and not group.bridge.initialized
]
for group_id in groups_needing_init:
self._initialize_group(group_id)
if timeout is not None and time.time() > deadline:
return False
time.sleep(self.latency)
# -- Stepping ----------------------------------------------------------
def _plan_step_locked(
self, requests: Dict[int, dict]
) -> Optional[Dict[str, Dict[int, dict]]]:
"""
Validate a step request and group it by aggregate group.
Returns a ``{group_id: {molecule_id: {"efield_au": ndarray}}}`` mapping,
or ``None`` if not every requested molecule is bound and initialized.
Must be called while holding ``self._lock``.
"""
wants = {int(mid) for mid in requests.keys()}
needs_prepare = any("init" in requests[mid] for mid in requests.keys()) or any(
int(mid) not in self._molecule_to_group for mid in requests.keys()
)
if needs_prepare:
payloads = {
int(mid): dict(requests[mid].get("init") or {"molecule_id": int(mid)})
for mid in requests.keys()
}
self._prepare_groups_locked(payloads)
if not self.all_bound(wants, require_init=True):
return None
grouped_requests: Dict[str, Dict[int, dict]] = {}
for mid in wants:
group_id = self._molecule_to_group[mid]
grouped_requests.setdefault(group_id, {})[mid] = {
"efield_au": np.asarray(requests[mid]["efield_au"], dtype=float)
}
return grouped_requests
def _send_step_to_group(
self, group_id: str, group_request: Dict[int, dict]
) -> bool:
"""Send one grouped fan-out step; return ``False`` on a dead bridge."""
group, st = self._group_and_bridge(group_id)
if group is None or st is None or not st.alive or not st.initialized:
self._pause()
return False
try:
group.step_codec.send(st.sock, group_request)
except (socket.timeout, _SocketClosed, OSError):
self._mark_group_dead(group_id, reason="send")
return False
return True
def _collect_group_result(
self, group_id: str, expected_ids: set[int], deadline: float
) -> Optional[Dict[int, dict]]:
"""
Receive and validate one group's reply.
Returns the per-molecule responses, or ``None`` if the bridge died or
the deadline passed (failure side effects are handled internally).
Raises ``RuntimeError`` if the bridge returns the wrong molecule ids.
"""
group, st = self._group_and_bridge(group_id)
if group is None or st is None or not st.alive:
self._pause()
return None
remaining = deadline - time.time()
if remaining <= 0.0:
return None
old_timeout = st.sock.gettimeout()
try:
st.sock.settimeout(max(0.0, remaining))
group_responses = group.result_codec.recv(st.sock)
except (socket.timeout, RuntimeError, _SocketClosed, OSError):
self._mark_group_dead(group_id, reason="recv")
return None
finally:
st.sock.settimeout(old_timeout)
actual = set(group_responses.keys())
if actual != expected_ids:
self._mark_group_dead(group_id, reason="protocol")
raise RuntimeError(
f"Aggregate group {group_id!r} returned molecule ids {sorted(actual)}, "
f"expected {sorted(expected_ids)}."
)
return group_responses
[docs]
def step_barrier(
self, requests: Dict[int, dict], timeout: Optional[float] = None
) -> Dict[int, dict]:
"""
Dispatch all requested fields group-by-group and collect grouped replies.
The caller-facing contract matches ``SocketHub.step_barrier``:
``responses[molid]`` contains ``{"amp": ndarray(3,), "extra": bytes}``.
"""
if self.paused:
return {}
deadline = self._deadline(timeout)
with self._lock:
grouped_requests = self._plan_step_locked(requests)
if not grouped_requests:
return {}
for group_id, group_request in grouped_requests.items():
if not self._send_step_to_group(group_id, group_request):
return {}
responses: Dict[int, dict] = {}
pending_groups = set(grouped_requests.keys())
# Fast path: a single group needs only a blocking recv, no selector.
if len(pending_groups) == 1:
group_id = next(iter(pending_groups))
expected = set(grouped_requests[group_id].keys())
group_responses = self._collect_group_result(group_id, expected, deadline)
if group_responses is None:
return {}
responses.update(group_responses)
return responses
# Multiple groups: wait on whichever bridge becomes readable next.
while pending_groups:
remaining = deadline - time.time()
if remaining <= 0.0:
return {}
try:
events = self._bridge_selector.select(timeout=min(remaining, 1.0))
except OSError:
return {}
if not events:
continue
for key, _mask in events:
group_id = key.data
if group_id not in pending_groups:
continue
expected = set(grouped_requests[group_id].keys())
group_responses = self._collect_group_result(
group_id, expected, deadline
)
if group_responses is None:
return {}
responses.update(group_responses)
pending_groups.discard(group_id)
return responses
# -- Shutdown ----------------------------------------------------------
def _snapshot_stop_targets(self):
"""
Snapshot the bridge groups and any stray clients to tear down.
Returns ``(group_clients, other_clients)`` where ``group_clients`` is a
list of ``(group_id, bridge_state, molecule_ids)`` and ``other_clients``
is a list of ``(client_key, client_state)`` not already covered above.
"""
with self._lock:
group_clients = [
(group_id, group.bridge, list(group.molecule_ids))
for group_id, group in self._groups.items()
if group.bridge is not None
]
seen = {id(st) for _, st, _ in group_clients}
other_clients = []
for key, st in list(self.clients.items()):
if st is None or id(st) in seen:
continue
seen.add(id(st))
other_clients.append((key, st))
return group_clients, other_clients
def _request_bridge_shutdown(self, group_clients) -> None:
"""Send ``STOP`` to every live bridge group."""
for _group_id, st, _molecule_ids in group_clients:
if not st.alive:
continue
try:
_send_msg(st.sock, STOP)
except Exception:
pass
def _await_bridge_byes(self, group_clients) -> None:
"""Wait briefly for each bridge to acknowledge ``STOP`` with ``BYE``."""
deadline = time.time() + max(2.0, 10.0 * self.latency)
while time.time() < deadline:
remaining_alive = False
for _group_id, st, _molecule_ids in group_clients:
if not st.alive:
continue
remaining_alive = True
try:
st.sock.settimeout(self.latency)
if _recv_msg(st.sock) == BYE:
st.alive = False
except (socket.timeout, _SocketClosed, OSError):
continue
if not remaining_alive:
break
time.sleep(self.latency)
def _teardown_stop_targets(self, group_clients, other_clients) -> None:
"""Clear all bridge/molecule state and close every snapshotted socket."""
sockets_to_close = []
with self._lock:
for group_id, st, molecule_ids in group_clients:
group = self._groups.get(group_id)
if group is not None and group.bridge is st:
group.bridge = None
self.clients.pop(group_id, None)
self._detach_sock_locked(st)
for mid in molecule_ids:
if self.bound.get(mid) is st:
self.bound[mid] = None
sockets_to_close.append((f"aggregate group {group_id!r}", st.sock))
for key, st in other_clients:
self.clients.pop(key, None)
self._detach_sock_locked(st)
sockets_to_close.append((f"client {key!r}", st.sock))
for label, sock in sockets_to_close:
self._log(f"DISCONNECTED: {label}")
_close_socket(sock)
[docs]
def stop(self):
"""
Stop the aggregate hub and clean up bridge groups coherently.
The base ``SocketHub.stop()`` assumes one client per molecule, which is
not true here. This override shuts down each bridge once and clears all
molecule bindings associated with that bridge.
"""
owned_bridges = list(self._owned_bridges)
self._stop = True
try:
self.serversock.close()
except Exception:
pass
group_clients, other_clients = self._snapshot_stop_targets()
self._request_bridge_shutdown(group_clients)
self._await_bridge_byes(group_clients)
self._teardown_stop_targets(group_clients, other_clients)
for selector in (self._selector, self._bridge_selector):
try:
selector.close()
except Exception:
pass
for handle in owned_bridges:
try:
handle.stop(wait=max(2.0, 10.0 * self.latency))
except Exception:
pass
# ---------------------------------------------------------------------------
# Node-local bridge
# ---------------------------------------------------------------------------
[docs]
class LocalSocketHubBridge:
"""
Bridge process/thread that fans out aggregate requests to a local SocketHub.
Upstream:
one TCP connection to :class:`AggregatedSocketHub`
Downstream:
one ordinary :class:`SocketHub` using either TCP or UNIX sockets,
connected to many existing MaxwellLink socket drivers.
"""
[docs]
def __init__(
self,
*,
group_id: str,
upstream_host: str,
upstream_port: int,
timeout: float = 60.0,
latency: float = 0.01,
local_host: str = "127.0.0.1",
local_port: Optional[int] = None,
local_unixsocket: Optional[str] = None,
):
if not str(group_id).strip():
raise ValueError("group_id must be a non-empty string")
if local_unixsocket is None and local_port is None:
sanitized = "".join(
ch if ch.isalnum() or ch in ("-", "_") else "_" for ch in str(group_id)
).strip("_")
local_unixsocket = f"agg_{sanitized or 'bridge'}"
self.group_id = str(group_id)
self.upstream_host = str(upstream_host)
self.upstream_port = int(upstream_port)
self.timeout = float(timeout)
self.latency = float(latency)
self.local_host = str(local_host)
self.local_port = int(local_port) if local_port is not None else None
self.local_unixsocket = local_unixsocket
self.local_hub = SocketHub(
host=self.local_host if self.local_unixsocket is None else None,
port=self.local_port,
unixsocket=self.local_unixsocket,
timeout=self.timeout,
latency=self.latency,
)
self._init_payloads: Dict[int, dict] = {}
self._request_cache: Dict[int, dict] = {}
self._upstream_sock: Optional[socket.socket] = None
self._thread: Optional[threading.Thread] = None
self._stop_event = threading.Event()
self._step_codec = _StepCodec()
self._result_codec = _ResultCodec()
@property
def local_endpoint(self) -> dict:
"""Return the downstream socket endpoint local drivers should connect to."""
if self.local_unixsocket is not None:
return {"unixsocket": self.local_unixsocket}
return {"host": self.local_host, "port": self.local_port}
def _ensure_local_hub_ready(self, init_payloads: Mapping[int, dict]) -> None:
"""Register downstream molecule ids and wait until local drivers bind."""
for mid in init_payloads.keys():
try:
self.local_hub.register_molecule(int(mid))
except ValueError:
pass
ok = self.local_hub.wait_until_bound(
dict(init_payloads),
require_init=True,
timeout=None,
)
if not ok:
raise RuntimeError(
f"Timed out waiting for local drivers in aggregate group {self.group_id!r}"
)
def _handle_group_init(self, payload: dict) -> None:
"""Accept a new group membership assignment from the upstream hub."""
incoming_group = str(payload.get("group_id", "")).strip()
if incoming_group != self.group_id:
raise RuntimeError(
f"Bridge {self.group_id!r} received AGGINIT for group {incoming_group!r}."
)
init_payloads = {
int(mid): {**dict(data), "molecule_id": int(mid)}
for mid, data in payload["init_payloads"].items()
}
self._ensure_local_hub_ready(init_payloads)
self._init_payloads = init_payloads
self._request_cache = {
int(mid): {"efield_au": np.zeros(3, dtype=float)}
for mid in init_payloads.keys()
}
def _build_local_requests(
self, efields: Mapping[int, np.ndarray]
) -> Dict[int, dict]:
"""
Map upstream efields onto the reusable downstream request cache.
When the requested molecule set matches the cache exactly we update the
cached arrays in place; otherwise we patch/extend the cache as needed.
"""
cache_hit = (
len(efields) == len(self._request_cache)
and self._request_cache
and all(int(mid) in self._request_cache for mid in efields.keys())
)
if cache_hit:
for mid, efield in efields.items():
np.copyto(
self._request_cache[int(mid)]["efield_au"],
np.asarray(efield, dtype=float).reshape(3),
)
return self._request_cache
requests: Dict[int, dict] = {}
for mid, efield in efields.items():
molid = int(mid)
cached = self._request_cache.get(molid)
if cached is None:
cached = {"efield_au": np.zeros(3, dtype=float)}
self._request_cache[molid] = cached
np.copyto(cached["efield_au"], np.asarray(efield, dtype=float).reshape(3))
requests[molid] = cached
return requests
def _run_local_step(self, efields: Mapping[int, np.ndarray]) -> Dict[int, dict]:
"""Fan out one grouped step to the downstream local hub."""
requests = self._build_local_requests(efields)
responses = self.local_hub.step_barrier(requests)
while not responses:
self._ensure_local_hub_ready(self._init_payloads)
responses = self.local_hub.step_barrier(requests)
return responses
[docs]
def run(self) -> None:
"""Run the bridge loop until the hub sends ``STOP`` or disconnects."""
sock = _connect_tcp_with_retry(
address=self.upstream_host,
port=self.upstream_port,
timeout=self.timeout,
)
self._upstream_sock = sock
_send_aggregate_hello(sock, group_id=self.group_id)
try:
while not self._stop_event.is_set():
msg = _recv_msg(sock)
if msg == AGGINIT:
payload = _json_loads_bytes(_recv_bytes(sock))
payload["init_payloads"] = {
int(mid): dict(data)
for mid, data in payload.get("init_payloads", {}).items()
}
self._handle_group_init(payload)
_send_msg(sock, AGGREADY)
elif msg == AGGSTEP:
efields = self._step_codec.recv(sock, header_already_read=True)
responses = self._run_local_step(efields)
self._result_codec.send(sock, responses)
elif msg == STOP:
try:
_send_msg(sock, BYE)
except OSError:
pass
break
else:
raise RuntimeError(f"Unexpected aggregate header: {msg!r}")
except (socket.timeout, _SocketClosed, OSError):
pass
finally:
_close_socket(sock)
self._upstream_sock = None
self.local_hub.stop()
[docs]
def start(self) -> threading.Thread:
"""Start the bridge loop in a daemon thread and return the thread handle."""
if self._thread is not None and self._thread.is_alive():
return self._thread
self._thread = threading.Thread(target=self.run, daemon=True)
self._thread.start()
return self._thread
[docs]
def stop(self, wait: float = 2.0) -> None:
"""Stop the bridge loop and close the downstream local hub."""
self._stop_event.set()
sock = self._upstream_sock
if sock is not None:
try:
sock.shutdown(socket.SHUT_RDWR)
except OSError:
pass
_close_socket(sock)
if self._thread is not None:
self._thread.join(timeout=float(wait))
self.local_hub.stop()
__all__ = [
"AggregatedBridge",
"AggregatedSocketHub",
"LocalSocketHubBridge",
"RemoteBridgeSpec",
"mxl_bridge_main",
"run_bridge_node",
]