# Source code for maxwelllink.sockets.sockets

# --------------------------------------------------------------------------------------#
# Copyright (c) 2026 MaxwellLink                                                       #
# This file is part of MaxwellLink. Repository: https://github.com/TaoELi/MaxwellLink  #
# If you use this code, always credit and cite arXiv:2512.06173.                       #
# See AGENTS.md and README.md for details.                                             #
# --------------------------------------------------------------------------------------#

"""
Socket layer for MaxwellLink drivers and servers.

This module implements a lightweight socket protocol inspired by i-PI
(https://ipi-code.org/) and provides:

- **SocketHub**: a multi-client server/poller for coordinating many driver
  connections with an FDTD engine.
- **Protocol constants**: ``STATUS``, ``READY``, ``HAVEDATA``, ``NEEDINIT``,
  ``INIT``, ...
- **EM aliases**: ``FIELDDATA``, ``GETSOURCE``, ``SOURCEREADY`` (1:1 mapping to
  ``POSDATA``/``GETFORCE``/``FORCEREADY``).
- **Low-level helpers**: ``_send_msg``, ``_recv_msg``, ``_send_array``/``_recv_array``,
  etc.
- **Exceptions**: ``_SocketClosed``.
"""

from __future__ import annotations

import json
import os
import selectors
import socket
import struct
import threading
import time
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple

import numpy as np

# ======================================================================
# Protocol constants and wire dtypes
# ======================================================================

# Pre-compiled little-endian packers: "<i" is a 32-bit signed integer (used by
# the int helpers below), "<d" is a 64-bit float.
_INT32 = struct.Struct("<i")
_FLOAT64 = struct.Struct("<d")

# Fixed header width (ASCII, space-padded)
HEADER_LEN = 12
# Canonical i-PI message codes
# Each tag is sent space-padded to HEADER_LEN bytes (see ``_pad12``).
STATUS = b"STATUS"
READY = b"READY"
HAVEDATA = b"HAVEDATA"
NEEDINIT = b"NEEDINIT"
INIT = b"INIT"
POSDATA = b"POSDATA"
GETFORCE = b"GETFORCE"
FORCEREADY = b"FORCEREADY"
STOP = b"STOP"
BYE = b"BYE"

# EM aliases for readability (same wire format)
# These are the very same bytes objects as the i-PI tags they alias.
FIELDDATA = POSDATA
GETSOURCE = GETFORCE
SOURCEREADY = FORCEREADY

# numpy dtypes on the wire (i-PI/ASE use float64 for reals, int32 for counts)
DT_FLOAT = np.float64
DT_INT = np.int32


class _SocketClosed(OSError):
    """
    Signal that the remote peer closed the connection mid-transfer.

    Raised by the low-level receive helpers when a read returns no data
    before the expected number of bytes has arrived. It derives from
    ``OSError`` so generic socket-error handlers also catch it.
    """


# ======================================================================
# Low-level wire helpers (headers, ints, arrays, byte strings)
# ======================================================================


def _pad12(msg: bytes) -> bytes:
    """
    Right-pad a message tag with spaces to the fixed header width.

    Parameters
    ----------
    msg : bytes
        Message tag to pad (at most ``HEADER_LEN`` bytes long).

    Returns
    -------
    bytes
        The tag followed by ASCII spaces, exactly ``HEADER_LEN`` bytes long.

    Raises
    ------
    ValueError
        If ``msg`` is longer than ``HEADER_LEN`` bytes.
    """

    if len(msg) <= HEADER_LEN:
        return msg.ljust(HEADER_LEN, b" ")
    raise ValueError("Header too long")


def _send_msg(sock: socket.socket, msg: bytes) -> None:
    """
    Write a message tag to the socket as a fixed-width header.

    Parameters
    ----------
    sock : socket.socket
        Connected socket.
    msg : bytes
        Message tag (e.g., ``b"STATUS"``); it is space-padded by ``_pad12``
        before being sent.
    """

    header = _pad12(msg)
    sock.sendall(header)


def _recvall(sock: socket.socket, n: int) -> bytes:
    """
    Block until exactly ``n`` bytes have been read from the socket.

    Parameters
    ----------
    sock : socket.socket
        Connected socket.
    n : int
        Number of bytes to read. A value of zero or less returns ``b""``
        without touching the socket.

    Returns
    -------
    bytes
        The ``n`` bytes read, in arrival order.

    Raises
    ------
    _SocketClosed
        If ``recv`` returns no data before ``n`` bytes have arrived.
    """

    data = bytearray()
    missing = n
    while missing > 0:
        piece = sock.recv(missing)
        if not piece:
            # An empty read means the peer shut the connection down.
            raise _SocketClosed("Peer closed")
        data += piece
        missing -= len(piece)
    return bytes(data)


def _recv_msg(sock: socket.socket) -> bytes:
    """
    Read one fixed-width header and return its tag.

    Parameters
    ----------
    sock : socket.socket
        Connected socket.

    Returns
    -------
    bytes
        The ``HEADER_LEN``-byte header with trailing whitespace removed.
    """

    return _recvall(sock, HEADER_LEN).rstrip()


def _send_array(sock: socket.socket, arr, dtype) -> None:
    """
    Send a NumPy array over a socket as raw C-order bytes.

    Parameters
    ----------
    sock : socket.socket
        Connected socket.
    arr : array-like
        Array data to send. Arrays with a zero-length axis are valid and
        result in zero bytes being written.
    dtype : numpy.dtype
        Data type to cast and send as (e.g., ``np.float64``).
    """

    a = np.asarray(arr, dtype=dtype, order="C")
    # Flatten before taking the byte view: ``memoryview.cast`` raises
    # TypeError on a multi-dimensional view that has a zero in its shape
    # (e.g. an empty ``(0, 3)`` array). ``reshape(-1)`` on a C-contiguous
    # array is a view, so no data is copied.
    sock.sendall(memoryview(a.reshape(-1)).cast("B"))


def _recv_array(sock: socket.socket, shape, dtype):
    """
    Receive a NumPy array of a given shape and dtype from a socket.

    Parameters
    ----------
    sock : socket.socket
        Connected socket.
    shape : tuple of int
        Expected array shape. A shape containing a zero-length axis is
        valid; nothing is read and an empty array is returned.
    dtype : numpy.dtype
        Expected dtype (e.g., ``np.float64``).

    Returns
    -------
    numpy.ndarray
        The received array with the specified shape and dtype.

    Raises
    ------
    _SocketClosed
        If the peer closes the connection during the transfer.
    """

    out = np.empty(shape, dtype=dtype, order="C")
    # Flatten before taking the byte view: ``memoryview.cast`` raises
    # TypeError on a multi-dimensional view that has a zero in its shape
    # (e.g. ``(0, 3)`` when the peer announces zero atoms). The reshape of a
    # freshly allocated C-order array is a view, so bytes received into
    # ``mv`` land directly in ``out``.
    mv = memoryview(out.reshape(-1)).cast("B")
    need = mv.nbytes
    got = 0
    while got < need:
        r = sock.recv_into(mv[got:], need - got)
        if r == 0:
            raise _SocketClosed("Peer closed")
        got += r
    return out


def _send_int(sock: socket.socket, x: int) -> None:
    """
    Write one integer to the socket in the ``_INT32`` wire format.

    Parameters
    ----------
    sock : socket.socket
        Connected socket.
    x : int
        Value to send; it is coerced with ``int()`` before packing.
    """

    packed = _INT32.pack(int(x))
    sock.sendall(packed)


def _recv_int(sock: socket.socket) -> int:
    """
    Read one integer in the ``_INT32`` wire format from the socket.

    Parameters
    ----------
    sock : socket.socket
        Connected socket.

    Returns
    -------
    int
        The decoded integer.

    Raises
    ------
    _SocketClosed
        If the peer closes the connection before all bytes have arrived.
    """

    width = _INT32.size
    raw = bytearray(width)
    window = memoryview(raw)
    filled = 0
    while filled < width:
        count = sock.recv_into(window[filled:], width - filled)
        if count == 0:
            # Zero bytes from recv_into means the peer closed the stream.
            raise _SocketClosed("Peer closed")
        filled += count
    (value,) = _INT32.unpack(raw)
    return value


def _send_bytes(sock: socket.socket, b: bytes) -> None:
    """
    Write a byte string preceded by its length.

    Parameters
    ----------
    sock : socket.socket
        Connected socket.
    b : bytes
        Payload to send. Its length goes out first via ``_send_int``; the
        payload itself is only written when it is non-empty.
    """

    size = len(b)
    _send_int(sock, size)
    if size:
        sock.sendall(b)


def _recv_bytes(sock: socket.socket) -> bytes:
    """
    Read a byte string that is preceded by its length.

    Parameters
    ----------
    sock : socket.socket
        Connected socket.

    Returns
    -------
    bytes
        The payload; ``b""`` when the announced length is zero.
    """

    length = _recv_int(sock)
    if not length:
        return b""
    return _recvall(sock, length)


# ======================================================================
# Compound payload codecs (i-PI compatible)
# ======================================================================


def _recv_posdata(sock: socket.socket):
    """
    Read the body of a POSDATA/FIELDDATA message.

    The body is read in wire order: two 3x3 float matrices, an integer
    count ``nat``, then an ``(nat, 3)`` float array.

    Parameters
    ----------
    sock : socket.socket
        Connected socket.

    Returns
    -------
    tuple
        ``(cell, icell, xyz)`` where:

        - ``cell`` : ``(3, 3)`` ndarray, transposed copy of the first matrix.
        - ``icell`` : ``(3, 3)`` ndarray, transposed copy of the second matrix.
        - ``xyz`` : ``(nat, 3)`` ndarray of positions (or effective field payload).
    """

    # The generator is consumed left to right during unpacking, so the two
    # matrices are read from the socket in order: cell first, then icell.
    cell, icell = (
        _recv_array(sock, (3, 3), DT_FLOAT).T.copy() for _ in range(2)
    )
    count = _recv_int(sock)
    xyz = _recv_array(sock, (count, 3), DT_FLOAT)
    return cell, icell, xyz


def _send_force_ready(
    sock: socket.socket,
    energy_ha: float,
    forces_Nx3_ha_per_bohr,
    virial_3x3_ha,
    more: bytes = b"",
):
    """
    Send a FORCEREADY/SOURCEREADY message with energy, forces, virial, and extras.

    All inputs are converted and validated before the first byte is written,
    so a bad argument never leaves a partially written message on the socket.

    Parameters
    ----------
    sock : socket.socket
        Connected socket.
    energy_ha : float
        Total energy (Hartree).
    forces_Nx3_ha_per_bohr : array-like, shape (N, 3)
        Forces (Hartree/Bohr).
    virial_3x3_ha : array-like, shape (3, 3)
        Virial tensor (Hartree). It is sent transposed.
    more : bytes, optional
        Extra payload (length-prefixed), e.g., JSON metadata.

    Raises
    ------
    ValueError
        If the forces array is not two-dimensional with three columns.
    """

    # Do every conversion up front; any failure happens before the header
    # goes out, keeping the byte stream in sync with the peer.
    energy = np.array([energy_ha], dtype=DT_FLOAT)
    forces = np.asarray(forces_Nx3_ha_per_bohr, dtype=DT_FLOAT)
    if forces.ndim != 2 or forces.shape[1] != 3:
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        raise ValueError(f"forces must have shape (N, 3), got {forces.shape}")
    virial_t = np.asarray(virial_3x3_ha, dtype=DT_FLOAT).T

    _send_msg(sock, FORCEREADY)
    _send_array(sock, energy, DT_FLOAT)
    _send_int(sock, forces.shape[0])
    _send_array(sock, forces, DT_FLOAT)
    _send_array(sock, virial_t, DT_FLOAT)
    _send_bytes(sock, more)


# ======================================================================
# EM convenience wrappers (i-PI compatible)
# ======================================================================


def _pack_init(sock: socket.socket, init_dict: dict):
    """
    Send an INIT handshake containing a JSON payload.

    The molecule id and the JSON body are prepared before anything is
    written, so an unserializable ``init_dict`` raises without leaving a
    dangling INIT header on the socket.

    Parameters
    ----------
    sock : socket.socket
        Connected socket.
    init_dict : dict
        Initialization dictionary (e.g., includes ``"molecule_id"``). The
        ``"molecule_id"`` entry defaults to ``0`` when absent.

    Raises
    ------
    TypeError
        If ``init_dict`` contains values that ``json.dumps`` cannot serialize.
    ValueError
        If ``"molecule_id"`` cannot be converted to ``int``.
    """

    # Build the full payload first; only then touch the socket.
    molid = int(init_dict.get("molecule_id", 0))
    init_bytes = json.dumps(init_dict).encode("utf-8")

    _send_msg(sock, INIT)
    _send_int(sock, molid)
    _send_bytes(sock, init_bytes)


# ======================================================================
# Fast-path constants and send/recv layout
# ======================================================================

# Pre-padded headers so the hot path does not call ``_pad12`` per step.
_FIELDDATA_HDR = _pad12(FIELDDATA)
_GETSOURCE_HDR = _pad12(GETSOURCE)
# Raw bytes of a 3x3 identity matrix in the wire float dtype.
_EYE3_BYTES = bytes(
    memoryview(np.ascontiguousarray(np.eye(3, dtype=DT_FLOAT))).cast("B")
)
# Packed integer 1 in the ``_INT32`` format.
_NAT1_BYTES = _INT32.pack(1)


# --------- fast-path send/recv layout ---------
#
# Send blob (196 bytes; written in place into a reusable bytearray):
#   [0  :12 ] FIELDDATA header
#   [12 :84 ] cell (3x3 float64, identity)
#   [84 :156] invcell (3x3 float64, identity)
#   [156:160] nat (int32 = 1)
#   [160:184] field vector (3 x float64)     <-- only this window changes
#   [184:196] GETSOURCE header
#
# Fixed reply (124 bytes; read into a reusable bytearray via recv_into):
#   [0  :12 ] SOURCEREADY header
#   [12 :20 ] energy (float64)
#   [20 :24 ] nat (int32, expected = 1)
#   [24 :48 ] forces (1 x 3 float64)
#   [48 :120] virial (3x3 float64)
#   [120:124] extra_len (int32)
#   (followed by `extra_len` trailing bytes of JSON/etc., read separately)

# Offsets/lengths of the send blob described in the layout table above:
# header (12) + cell (72) + invcell (72) + nat (4) precede the field window;
# the field itself (24) and the GETSOURCE header (12) complete the blob.
_SEND_FIELD_OFFSET = 12 + 72 + 72 + 4  # = 160
_SEND_TOTAL_LEN = _SEND_FIELD_OFFSET + 24 + 12  # = 196
# Template blob with a zeroed 24-byte field window.
_SEND_TEMPLATE = (
    _FIELDDATA_HDR
    + _EYE3_BYTES
    + _EYE3_BYTES
    + _NAT1_BYTES
    + b"\x00" * 24
    + _GETSOURCE_HDR
)
# Import-time sanity check that the template matches the computed length.
assert len(_SEND_TEMPLATE) == _SEND_TOTAL_LEN

# Offsets/lengths of the fixed reply described in the layout table above.
_REPLY_FIXED_LEN = 12 + 8 + 4 + 24 + 72 + 4  # = 124
_REPLY_NAT_OFFSET = 12 + 8  # = 20
_REPLY_FORCES_OFFSET = 12 + 8 + 4  # = 24
_REPLY_EXTRA_LEN_OFFSET = 12 + 8 + 4 + 24 + 72  # = 120

# Pre-compiled little-endian packers: three float64 values, and one int32
# (same format string as ``_INT32``).
_STRUCT_3D = struct.Struct("<3d")
_STRUCT_I = struct.Struct("<i")


# ======================================================================
# Module-level utilities (host/port discovery and MPI helpers)
# ======================================================================


def get_available_host_port(localhost=True, save_to_file=None) -> Tuple[str, int]:
    """
    Ask the OS for an available TCP port and report the address to connect to.

    Parameters
    ----------
    localhost : bool, default: True
        If True, bind to the localhost interface ("127.0.0.1"). If False, bind
        to all interfaces ("0.0.0.0") and report the address of the outbound
        network interface instead of the loopback address.
    save_to_file : str or None, default: None
        If provided, save the selected host and port to the given file with
        filename provided by `save_to_file`. The first line contains the host,
        and the second line contains the port. Only the MPI master rank
        writes the file.

    Returns
    -------
    tuple
        ``(host, port)`` pair, e.g., ``("127.0.0.1", 34567)``.

    Notes
    -----
    The probe socket is closed before returning, so the port is only known
    to have been free at the time of the call.

    With ``localhost=False`` the outbound address is discovered by connecting
    a UDP socket towards 8.8.8.8. On hosts without a route to that address the
    connect raises ``OSError``; the function then falls back to resolving the
    local hostname, and finally to ``"127.0.0.1"``, instead of failing.
    """
    bind_addr = "127.0.0.1" if localhost else "0.0.0.0"
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Port 0 lets the OS choose any free port.
        s.bind((bind_addr, 0))
        port = s.getsockname()[1]

    ip = "127.0.0.1"
    if not localhost:
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as tmp:
                # The connect is only used to learn which local address the
                # OS would route through.
                tmp.connect(("8.8.8.8", 80))
                ip = tmp.getsockname()[0]
        except OSError:
            # No usable route (e.g. an isolated compute node): fall back to
            # the address the local hostname resolves to, keeping the
            # loopback default if that lookup fails as well.
            try:
                ip = socket.gethostbyname(socket.gethostname())
            except OSError:
                pass

    # save host and port number to a file so mxl_driver can read it
    if am_master() and save_to_file is not None:
        with open(save_to_file, "w") as f:
            f.write(f"{ip}\n{port}\n")
    return ip, port
def _mpi_comm():
    """
    Return ``MPI.COMM_WORLD`` when ``mpi4py`` can be imported, else ``None``.

    This keeps the optional ``mpi4py`` import in one place so that the MPI
    helpers in this module can treat a missing or unusable ``mpi4py`` as a
    single-process (rank 0) world.
    """
    world = None
    try:
        from mpi4py import MPI

        world = MPI.COMM_WORLD
    except Exception:
        # Any failure while importing or querying mpi4py means "no MPI".
        world = None
    return world


# helper function to determine whether this processor is the MPI master using mpi4py
def am_master():
    """
    Tell whether this process is the MPI master rank (rank 0).

    Returns
    -------
    bool
        ``True`` on rank 0, ``False`` on every other rank.

    Notes
    -----
    The communicator comes from ``_mpi_comm()``. When that returns ``None``
    (``mpi4py`` unavailable) the single process is treated as rank 0 and the
    result is ``True``.
    """
    comm = _mpi_comm()
    if comm is None:
        return True
    return comm.Get_rank() == 0
# helper function to broadcast a value from master to all MPI ranks
def mpi_bcast_from_master(value):
    """
    Broadcast a Python value from the master rank (root 0) to all MPI ranks.

    Parameters
    ----------
    value : any
        The value to broadcast.

    Returns
    -------
    any
        The result of ``comm.bcast(value, root=0)``, or ``value`` itself
        when no MPI communicator is available.
    """
    comm = _mpi_comm()
    if comm is None:
        return value
    return comm.bcast(value, root=0)
# ======================================================================
# Per-client state and the socket hub
# ======================================================================


@dataclass
class _ClientState:
    """
    Per-connection record kept by the socket hub.

    Attributes
    ----------
    sock : socket.socket
        Connected client socket.
    address : str
        Peer address string.
    molecule_id : int
        Bound molecule identifier (``-1`` if unbound).
    last_amp : numpy.ndarray or None
        Last source amplitude vector ``(3,)``; ``None`` by default.
    pending_send : bool
        Whether a field has been dispatched but not yet committed.
    initialized : bool
        Whether INIT has been completed.
    alive : bool
        Connection liveness flag; a new record starts out alive.
    extras : dict
        Arbitrary metadata associated with the client. Every instance gets
        its own fresh dictionary.
    """

    # Required fields, supplied when a connection is accepted.
    sock: socket.socket
    address: str
    molecule_id: int

    # Mutable bookkeeping with defaults.
    last_amp: Optional[np.ndarray] = None  # last source amplitude (3,)
    pending_send: bool = False
    initialized: bool = False
    alive: bool = True
    extras: dict = field(default_factory=dict)
[docs] class SocketHub: """ Socket server coordinating multiple driver connections with an FDTD engine. This server: - Accepts and tracks many driver connections. - Handles initialization handshakes, field dispatch, and result collection. - Provides a barrier-style step to send fields and receive source amplitudes from all registered molecules. """
[docs] def __init__( self, host: Optional[str] = None, port: Optional[int] = 31415, unixsocket: Optional[str] = None, timeout: float = 60000.0, latency: float = 0.01, ): """ Initialize the socket hub. Parameters ---------- host : str or None, default: None Host address for AF_INET sockets. Ignored when using a UNIX socket. port : int or None, default: 31415 TCP port for AF_INET sockets. Ignored for UNIX sockets. unixsocket : str or None, default: None Path (or name under ``/tmp/socketmxl_*``) for a UNIX domain socket. When provided, ``host`` and ``port`` are ignored. timeout : float, default: 60000.0 Socket timeout (seconds) for client operations. latency : float, default: 0.01 Polling sleep (seconds) between hub sweeps; can be very small for local runs. """ self.unixsocket_path = None if am_master(): if unixsocket: self.serversock = socket.socket(socket.AF_UNIX) # mirror i-PI's /tmp/ipi_* default when given a name if not unixsocket.startswith("/"): unixsocket = f"/tmp/socketmxl_{unixsocket}" self.unixsocket_path = unixsocket if os.path.exists(self.unixsocket_path): probe = socket.socket(socket.AF_UNIX) try: probe.settimeout(0.25) probe.connect(self.unixsocket_path) except FileNotFoundError: pass except ConnectionRefusedError: try: os.unlink(self.unixsocket_path) except FileNotFoundError: pass else: probe.close() raise RuntimeError( f"Socket path {self.unixsocket_path} already in use" ) finally: try: probe.close() except Exception: pass self.serversock.bind(unixsocket) self._where = unixsocket else: self.serversock = socket.socket(socket.AF_INET) self.serversock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) host = host or "" port = port or 31415 self.serversock.bind((host, port)) self._where = f"{host}:{port}" self.serversock.listen(16384) self.serversock.settimeout(0.25) self.timeout = float(timeout) self.latency = float(latency) # key: molecule_id or temp id self.clients: Dict[int, _ClientState] = {} # peer -> molecule_id self.addrmap: Dict[str, int] = {} 
self._stop = False self._lock = threading.RLock() self._accept_th = threading.Thread(target=self._accept_loop, daemon=True) self._accept_th.start() # assign a molecular id accumulator self._molecule_id_counter = 0 # Persistent selector — clients are registered on bind, not per step. self._selector = selectors.DefaultSelector() # Reusable scratch buffers on the hot path: # _scratch_send: the 196-byte FIELDDATA+GETSOURCE blob, with # the 24-byte field window at _SEND_FIELD_OFFSET patched in # place each step via struct.pack_into (no per-step allocation). # _scratch_recv: the 124-byte fixed SOURCEREADY reply, filled # by a single recv_into loop and parsed via struct. self._scratch_send = bytearray(_SEND_TEMPLATE) self._scratch_recv = bytearray(_REPLY_FIXED_LEN) self._scratch_recv_mv = memoryview(self._scratch_recv) # molecule_id -> _ClientState (locked client) self.bound: Dict[int, _ClientState] = {} # molecule ids we expect to serve self.expected: set[int] = set() # global pause when any driver is down self.paused = False # holds a frozen barrier until it successfully commits self._inflight = None
def _accept_loop(self) -> None:
    """
    Accept-loop thread: accept new connections and register temporary clients.

    Runs until ``self._stop`` is set or the server socket raises ``OSError``
    (for example after it has been closed). Each accepted socket is stored
    in ``self.clients`` under the temporary key ``id(csock)`` with
    ``molecule_id=-1`` until an INIT handshake binds it.
    """
    while not self._stop:
        try:
            csock, addr = self.serversock.accept()
        except socket.timeout:
            # Accept timed out: loop again so ``self._stop`` is re-checked.
            continue
        except OSError:
            break

        # Trim latency and keep long-lived connections healthy.
        try:
            csock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            # Only for AF_INET; will raise on AF_UNIX
            csock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        except (OSError, AttributeError):
            pass  # AF_UNIX or platform without TCP_NODELAY

        # String addresses are used verbatim; tuple addresses become "host:port".
        peer = addr if isinstance(addr, str) else f"{addr[0]}:{addr[1]}"
        csock.settimeout(self.timeout)
        st = _ClientState(sock=csock, address=peer, molecule_id=-1)
        with self._lock:
            # temp key: use id(csock) until INIT binds molecule_id
            self.clients[id(csock)] = st

def _maybe_init_client(self, st: _ClientState, init_payload: dict) -> None:
    """
    Send INIT to a client with the given payload and mark it initialized.

    Parameters
    ----------
    st : _ClientState
        Client state to initialize.
    init_payload : dict
        Initialization payload (e.g., contains ``"molecule_id"``).
    """
    _pack_init(st.sock, init_payload)
    st.initialized = True

def _register_sock(self, sock: socket.socket, molid: int) -> None:
    """
    Register a client's socket with the persistent selector.

    Called once at bind time. If the socket is already registered (for
    example after a rebind/reconnect), we replace the old registration so
    future ``select`` events carry the up-to-date molecule id. All selector
    errors are swallowed; this method never raises them to the caller.

    Parameters
    ----------
    sock : socket.socket
        The client socket.
    molid : int
        Molecule id to attach as the selector ``data`` payload.
    """
    try:
        self._selector.register(sock, selectors.EVENT_READ, data=int(molid))
    except (KeyError, ValueError):
        # Already registered under this fd — swap the data payload.
        try:
            self._selector.unregister(sock)
            self._selector.register(sock, selectors.EVENT_READ, data=int(molid))
        except (KeyError, ValueError, OSError):
            pass
    except OSError:
        pass

def _unregister_sock(self, sock: socket.socket) -> None:
    """
    Unregister a client socket from the persistent selector.

    Safe to call with a socket that was never registered or has already
    been closed; errors are swallowed so disconnect paths stay simple.

    Parameters
    ----------
    sock : socket.socket
        The client socket.
    """
    try:
        self._selector.unregister(sock)
    except (KeyError, ValueError, OSError):
        pass

def _mark_dead(
    self,
    st: _ClientState,
    molid: Optional[int] = None,
    reason: Optional[str] = None,
) -> bool:
    """
    Mark a client dead, unregister it from the selector, and clear binding.

    This centralizes the bookkeeping that used to be duplicated across the
    STATUS-based sweep and the shutdown paths. It is safe to call from any
    phase and takes ``self._lock`` only briefly for the ``bound`` mutation
    so blocking I/O never runs while the lock is held.

    The socket itself is not closed here and the entry stays in
    ``self.clients``.

    Parameters
    ----------
    st : _ClientState
        The client whose socket failed.
    molid : int or None, optional
        The molecule id the client was bound to. Falls back to
        ``st.molecule_id`` when ``None``.
    reason : str or None, optional
        Short tag for the disconnect log line (e.g. ``"send"``, ``"recv"``).
        When ``None`` the bare ``DISCONNECTED: ...`` form is logged.

    Returns
    -------
    bool
        ``True`` if a bound molecule was actually released, else ``False``.
    """
    st.alive = False
    self._unregister_sock(st.sock)
    if molid is None:
        molid = st.molecule_id
    if molid is not None and molid >= 0:
        with self._lock:
            # Only release the slot if this exact client still owns it.
            if self.bound.get(molid) is st:
                tag = f" ({reason})" if reason else ""
                self._log(f"DISCONNECTED{tag}: mol {molid} from {st.address}")
                self.bound[molid] = None
                return True
    return False

def _dispatch_field(
    self, st: _ClientState, blob: "bytes | bytearray | memoryview", meta: dict
) -> None:
    """
    Send a pre-packed FIELDDATA+GETSOURCE blob to one client in a single call.

    This is the hot-path send used by :meth:`step_barrier`. The caller is
    responsible for packing the field vector into the shared scratch buffer
    (via ``struct.pack_into``) so a whole group of clients sharing the same
    field can reuse the same blob.

    Parameters
    ----------
    st : _ClientState
        Target client state.
    blob : bytes-like
        Pre-packed 196-byte request buffer.
    meta : dict
        Optional metadata to attach to this send (stored in ``st.extras``).

    Raises
    ------
    _SocketClosed or OSError
        If the client disconnects during send. The caller is responsible
        for calling :meth:`_mark_dead`.
    """
    st.sock.sendall(blob)
    st.pending_send = True
    if meta:
        st.extras.update(meta)

def _read_source_ready(self, st: _ClientState) -> Tuple[np.ndarray, bytes]:
    """
    Read a SOURCEREADY/FORCEREADY reply into the shared scratch buffer.

    The reply's fixed 124-byte prefix (header, energy, nat, forces, virial,
    extra_len) is drained in a single ``recv_into`` loop into
    ``self._scratch_recv`` and parsed with ``struct.unpack_from`` — no numpy
    temporaries, no per-field ``_recv_array`` calls. Only a single 3-element
    ``np.array`` is allocated at the end to carry the amplitude back to the
    caller.

    The shared scratch buffer is safe because :meth:`step_barrier` drains
    selector events serially in the main thread — only one reply is being
    parsed at any given time.

    Side effects: stores the amplitude in ``st.last_amp`` and clears
    ``st.pending_send``.

    Parameters
    ----------
    st : _ClientState
        Client whose reply is being drained. Assumes the hub has already
        sent the combined FIELDDATA+GETSOURCE request and the kernel
        reported the socket readable.

    Returns
    -------
    tuple
        ``(amp_vec3, extra_bytes)`` where ``amp_vec3`` is a ``(3,)``
        ``np.ndarray`` and ``extra_bytes`` is the trailing variable blob
        (``b""`` when the reported length is not positive).

    Raises
    ------
    _SocketClosed or OSError
        If the peer disconnects, the header is not SOURCEREADY, or the
        reported ``nat`` is not the EM-protocol-expected value of 1.
    """
    sock = st.sock
    mv = self._scratch_recv_mv
    n = 0
    # Keep reading until the whole fixed-size prefix has arrived.
    while n < _REPLY_FIXED_LEN:
        r = sock.recv_into(mv[n:], _REPLY_FIXED_LEN - n)
        if r == 0:
            raise _SocketClosed("Peer closed")
        n += r

    # Header must be SOURCEREADY (the 12-byte ASCII tag, space-padded).
    if bytes(mv[:HEADER_LEN]).rstrip() != SOURCEREADY:
        raise _SocketClosed(
            f"Expected {SOURCEREADY!r}, got {bytes(mv[:HEADER_LEN]).rstrip()!r}"
        )

    # EM protocol contract: drivers always send nat=1.
    nat = _STRUCT_I.unpack_from(mv, _REPLY_NAT_OFFSET)[0]
    if nat != 1:
        raise _SocketClosed(f"EM fast-path expected nat=1, got nat={nat}")

    fx, fy, fz = _STRUCT_3D.unpack_from(mv, _REPLY_FORCES_OFFSET)
    extra_len = _STRUCT_I.unpack_from(mv, _REPLY_EXTRA_LEN_OFFSET)[0]
    extra = _recvall(sock, extra_len) if extra_len > 0 else b""

    amp = np.array((fx, fy, fz), dtype=float)
    st.last_amp = amp
    st.pending_send = False
    return amp, extra

def _progress_binds_locked(self, init_payloads: Dict[int, dict]) -> None:
    """
    Drive INIT handshakes for any fresh (unbound) clients.

    Walks ``self.clients`` for entries whose ``molecule_id < 0`` (the temp
    state created by the accept loop) and, for each one, picks an expected
    molecule ID from ``init_payloads`` that is not yet bound and sends
    ``INIT`` directly. This replaces the old STATUS/NEEDINIT round-trip:
    both the Python and LAMMPS drivers accept INIT unconditionally as the
    first message from the hub, so the extra poll is unnecessary.

    Parameters
    ----------
    init_payloads : dict[int, dict]
        Mapping of molecule ID to the INIT payload to send for that ID.

    Notes
    -----
    This method assumes ``self._lock`` is held by the caller.
    """
    pending_ids = [
        int(mid) for mid in init_payloads.keys() if self.bound.get(int(mid)) is None
    ]
    if not pending_ids:
        return

    fresh_clients = [
        (k, st)
        for k, st in list(self.clients.items())
        if st is not None and st.alive and st.molecule_id < 0
    ]
    for st_key, st in fresh_clients:
        if not pending_ids:
            break
        chosen = pending_ids.pop(0)
        payload = init_payloads.get(chosen) or {"molecule_id": chosen}
        # Always send the id actually being bound, whatever the payload said.
        payload = {**payload, "molecule_id": chosen}
        try:
            self._bind_client_locked(st, int(chosen), payload, st_key)
        except (socket.timeout, _SocketClosed, OSError):
            st.alive = False
            # put the id back so another fresh client can claim it
            pending_ids.insert(0, chosen)

def _bind_client_locked(
    self, st: _ClientState, molid: int, init_payload: dict, st_key
) -> bool:
    """
    Bind a client to a molecule ID if available and perform INIT.

    On success the client is re-keyed in ``self.clients`` from ``st_key`` to
    ``molid``, recorded in ``self.bound`` and ``self.addrmap``, registered
    with the persistent selector, and any frozen barrier entry for ``molid``
    is flagged for re-dispatch.

    Parameters
    ----------
    st : _ClientState
        Client to bind.
    molid : int
        Molecule ID to bind to.
    init_payload : dict
        INIT payload to send.
    st_key : int
        Temporary key under which the client is stored.

    Returns
    -------
    bool
        ``True`` if binding succeeded, otherwise ``False`` (the molecule ID
        is already bound to another client).
    """
    if self.bound.get(molid) is None:
        self._maybe_init_client(st, init_payload)
        st.molecule_id = molid
        self.bound[molid] = st
        self.addrmap[st.address] = molid
        self.clients[molid] = st
        if st_key != molid:
            try:
                del self.clients[st_key]
            except KeyError:
                pass
        # Register with the persistent selector so Phase B of
        # step_barrier doesn't have to re-register on every call.
        self._register_sock(st.sock, molid)
        address = st.address
        self._log(f"CONNECTED: mol {molid} <- {address}")
        # If this molid is part of a frozen barrier, force re-dispatch.
        self._reset_inflight_for(molid)
        st.pending_send = False  # defensive: this is a fresh socket
        return True
    return False

def _log(self, *a) -> None:
    """
    Log a message with the ``[SocketHub]`` prefix.
    """
    print("[SocketHub]", *a)

def _pause(self) -> None:
    """
    Pause the hub (used when a driver disconnects mid-barrier).
    """
    self.paused = True

def _resume(self) -> None:
    """
    Resume the hub after a pause.
    """
    self.paused = False

def _reset_inflight_for(self, molid: int) -> None:
    """
    Force re-dispatch for ``molid`` in a frozen barrier after reconnect.

    Does nothing when no barrier is in flight or ``molid`` is not part of it.

    Parameters
    ----------
    molid : int
        Molecule ID to reset in the current barrier state.
    """
    if self._inflight and (molid in self._inflight["wants"]):
        self._inflight["sent"][molid] = False

def _find_free_molecule_id(self) -> int:
    """
    Find and return an available molecule ID not already registered.

    Advances ``self._molecule_id_counter`` past every candidate it inspects,
    including the one returned.

    Returns
    -------
    int
        A unique molecule ID.
    """
    while True:
        molecule_id = self._molecule_id_counter
        self._molecule_id_counter += 1
        if molecule_id not in self.expected:
            return molecule_id

# -------------- public API --------------
def register_molecule(self, molecule_id: int) -> None:
    """
    Reserve a slot for a given molecule ID (client may connect later).

    The ID is normalized with ``int()`` before the duplicate check so the
    value that is tested is the same value that is stored.

    Parameters
    ----------
    molecule_id : int
        Molecule ID to register.

    Raises
    ------
    ValueError
        If the molecule ID is already registered.
    """
    # Normalize once: the stored key is always an int, so the duplicate
    # check must use the same normalized value.
    mid = int(molecule_id)
    with self._lock:
        if mid in self.expected:
            raise ValueError(f"Molecule ID {mid} already registered!")
        # No explicit state needed yet; client binds on INIT.
        self.expected.add(mid)
        self.bound.setdefault(mid, None)
def register_molecule_return_id(self) -> int:
    """
    Reserve a slot under an automatically chosen molecule ID.

    The ID is picked and registered while holding the hub lock, so the
    choice and the registration happen as one step.

    Returns
    -------
    int
        The newly assigned molecule ID.
    """
    with self._lock:
        new_id = self._find_free_molecule_id()
        self.register_molecule(new_id)
    return new_id
def step_barrier(
    self, requests: Dict[int, dict], timeout: Optional[float] = None
) -> Dict[int, dict]:
    """
    Barrier step: dispatch fields and collect source amplitudes from all clients.

    Coordinates sending fields, waiting for results, and jointly committing
    the results once every requested molecule is ready. A frozen barrier is
    reused if a disconnect or timeout occurs mid-step. Replies that were
    already collected are kept inside the frozen barrier, so a retry only
    waits for the molecules that have not answered yet.

    Parameters
    ----------
    requests : dict[int, dict]
        Mapping from molecule ID to request dict with keys:

        - ``"efield_au"`` : array-like ``(3,)`` field vector in a.u.
        - ``"meta"`` : dict, optional metadata per send.
        - ``"init"`` : dict, optional INIT payload for first bind.

        Ignored (except for ``"init"``) while a frozen barrier is in flight.
    timeout : float, optional
        Maximum time (seconds) to wait for the barrier to complete.
        Defaults to the hub's ``timeout`` setting.

    Returns
    -------
    dict[int, dict]
        Mapping ``molid -> {"amp": ndarray(3,), "extra": bytes}``. Returns
        ``{}`` when paused, on abort, or if the barrier is incomplete.
    """
    if self.paused:
        return {}

    deadline = time.time() + (timeout if timeout is not None else self.timeout)

    # If a barrier is already in flight, ignore new 'requests' and reuse the
    # frozen one.
    if self._inflight is None:
        wants = set(int(k) for k in requests.keys())
        self._inflight = {
            "wants": wants,
            "efields": {
                int(mid): np.asarray(
                    requests[mid]["efield_au"], dtype=DT_FLOAT
                ).copy()
                for mid in wants
            },
            "meta": {int(mid): requests[mid].get("meta", {}) for mid in wants},
            "sent": {int(mid): False for mid in wants},
            "results": {},
        }

    inflight = self._inflight
    wants = set(inflight["wants"])
    # Replies live in the frozen barrier rather than in a local variable, so
    # a timeout or disconnect does not discard answers that already arrived.
    results: Dict[int, dict] = inflight.setdefault("results", {})

    # --- hard gate: do not dispatch fields until everyone is bound ---
    with self._lock:
        if not self.all_bound(wants, require_init=True):
            init_payloads = {
                int(mid): (
                    requests.get(mid, {}).get("init") or {"molecule_id": int(mid)}
                )
                for mid in wants
            }
            self._progress_binds_locked(init_payloads)
            return {}

        # Snapshot the (mid, st, efield, meta) tuples we will send to.
        # Everything below runs without self._lock held, so the accept
        # thread and background bookkeeping cannot be starved by blocking
        # send/recv syscalls.
        snapshot = []
        for mid in wants:
            if inflight["sent"].get(mid, False):
                continue
            # About to (re-)dispatch: a reply stored for an earlier dispatch
            # would be stale, so drop it and wait for the new one.
            results.pop(mid, None)
            st = self.bound.get(mid)
            if st is None or not st.alive:
                self._pause()
                self._reset_inflight_for(mid)
                return {}
            snapshot.append(
                (
                    int(mid),
                    st,
                    inflight["efields"][mid],
                    inflight["meta"][mid],
                )
            )

    # --- Phase A: pipeline dispatch (FIELDDATA + GETSOURCE in one send) ---
    #
    # We reuse a single 196-byte scratch bytearray for every send; only
    # the 24-byte field window at offset _SEND_FIELD_OFFSET is rewritten
    # via struct.pack_into. Clients sharing an identical field vector
    # (common in Meep runs that dedup by polarization fingerprint) are
    # grouped so we pack once per unique field instead of once per client.
    scratch = self._scratch_send
    groups: Dict[Tuple[float, float, float], list] = {}
    for mid, st, efield, meta in snapshot:
        ef = np.asarray(efield, dtype=DT_FLOAT).reshape(3)
        key = (float(ef[0]), float(ef[1]), float(ef[2]))
        groups.setdefault(key, []).append((mid, st, meta))

    for fkey, members in groups.items():
        _STRUCT_3D.pack_into(scratch, _SEND_FIELD_OFFSET, fkey[0], fkey[1], fkey[2])
        for mid, st, meta in members:
            try:
                self._dispatch_field(st, scratch, meta)
            except (socket.timeout, _SocketClosed, OSError):
                self._mark_dead(st, mid, reason="send")
                self._pause()
                self._reset_inflight_for(mid)
                return {}
            inflight["sent"][mid] = True

    # --- Phase B: collect SOURCEREADY replies via the persistent selector ---
    #
    # The selector has every bound client registered (from
    # _bind_client_locked), so we do NOT register per call. Only molecules
    # without a stored reply are awaited; the rest answered in an earlier
    # attempt of this same frozen barrier.
    pending_mids = {int(mid) for mid in wants if mid not in results}
    sel = self._selector

    while pending_mids:
        remaining = deadline - time.time()
        if remaining <= 0:
            break
        # Cap the wait so we periodically re-check the deadline.
        events = sel.select(timeout=min(remaining, 1.0))
        if not events:
            continue

        for key, _mask in events:
            mid = key.data
            if mid not in pending_mids:
                # Spurious wake (stale registration or unrelated driver);
                # leave it for later and keep draining our own mids.
                continue
            with self._lock:
                st = self.bound.get(mid)
            if st is None or not st.alive:
                pending_mids.discard(mid)
                self._pause()
                self._reset_inflight_for(mid)
                return {}
            try:
                amp, extra = self._read_source_ready(st)
            except (socket.timeout, _SocketClosed, OSError):
                self._mark_dead(st, mid, reason="recv")
                pending_mids.discard(mid)
                self._pause()
                self._reset_inflight_for(mid)
                return {}
            results[mid] = {"amp": amp, "extra": extra}
            pending_mids.discard(mid)

    if pending_mids:
        # Timed out waiting for replies; keep the frozen barrier (including
        # the replies gathered so far) for retry.
        return {}

    # SUCCESS — clear the frozen barrier
    self._inflight = None
    return results
def all_bound(self, molecule_ids, require_init=True):
    """
    Report whether every listed molecule has a live bound client.

    Parameters
    ----------
    molecule_ids : iterable of int
        Molecule IDs to check.
    require_init : bool, default: True
        When true, a client additionally has to have completed INIT.

    Returns
    -------
    bool
        ``True`` when each ID maps to a live client (initialized if
        requested); ``False`` as soon as one does not. An empty iterable
        yields ``True``.
    """

    def _is_ready(state) -> bool:
        # Missing or dead client -> not ready.
        if state is None or not state.alive:
            return False
        # Uninitialized clients only count when INIT is not required.
        return not (require_init and not state.initialized)

    with self._lock:
        return all(_is_ready(self.bound.get(int(m))) for m in molecule_ids)
def wait_until_bound(self, init_payloads: dict, require_init=True, timeout=None):
    """
    Block until every requested molecule ID has a bound client.

    Between checks, INIT is pushed to fresh (not yet bound) clients for the
    IDs that are still missing, then the call sleeps for ``self.latency``.
    The hub is resumed once all IDs are bound.

    Parameters
    ----------
    init_payloads : dict[int, dict]
        Mapping from molecule ID to the INIT payload to use on bind.
    require_init : bool, default: True
        Also require that clients completed INIT.
    timeout : float or None, optional
        Maximum time to wait (seconds). When ``None`` the deadline is never
        enforced and the call keeps waiting until all IDs are bound.

    Returns
    -------
    bool
        ``True`` once all requested IDs are bound; ``False`` if an explicit
        ``timeout`` elapsed first.
    """
    wanted = {int(k) for k in init_payloads}
    limit = self.timeout if timeout is None else timeout
    deadline = time.time() + limit
    payloads = {int(mid): payload for mid, payload in init_payloads.items()}
    enforce_deadline = timeout is not None

    while not self.all_bound(wanted, require_init=require_init):
        # Push INIT to any fresh unbound clients. The accept loop has already
        # enqueued them; STATUS is not used to probe for NEEDINIT.
        with self._lock:
            pending_ids = {mid for mid in wanted if self.bound.get(mid) is None}
            if pending_ids:
                self._progress_binds_locked(
                    {
                        mid: payloads.get(mid, {"molecule_id": mid})
                        for mid in pending_ids
                    }
                )

        if enforce_deadline and time.time() > deadline:
            return False
        time.sleep(self.latency)

    self._resume()
    return True
def graceful_shutdown(self, reason: Optional[str] = None, wait: float = 2.0):
    """
    Politely ask all connected drivers to exit and wait briefly for ``BYE``.

    ``STOP`` is sent to every live client, then the hub sweeps the clients
    for a ``BYE`` reply until ``wait`` seconds have passed or no live client
    remains, whichever comes first.

    Parameters
    ----------
    reason : str or None, optional
        Optional reason to log for shutdown.
    wait : float, default: 2.0
        Maximum number of seconds to wait for clean replies.
    """
    if reason:
        self._log(f"SHUTDOWN: {reason}")

    with self._lock:
        for st in list(self.clients.values()):
            if not st or not st.alive:
                continue
            try:
                _send_msg(st.sock, STOP)
            except Exception:
                # Best effort: a client we cannot reach is simply marked dead.
                if self._mark_dead(st):
                    self._pause()

    deadline = time.time() + float(wait)
    while time.time() < deadline:
        with self._lock:
            # Nothing left to wait for: every client has left or is dead.
            if not any(st and st.alive for st in self.clients.values()):
                return
        time.sleep(self.latency)
        with self._lock:
            for st in list(self.clients.values()):
                if not st or not st.alive:
                    continue
                try:
                    # Make reads snappy during shutdown
                    st.sock.settimeout(self.latency)
                    msg = _recv_msg(st.sock)
                except (socket.timeout, _SocketClosed, OSError):
                    # Either no message yet or peer closed already; keep sweeping
                    continue
                if msg == BYE:
                    # Clean close on our side
                    self._mark_dead(st)
                    try:
                        st.sock.shutdown(socket.SHUT_RDWR)
                    except Exception:
                        pass
                    try:
                        st.sock.close()
                    except Exception:
                        pass
def stop(self):
    """
    Stop accepting new connections, request clients to exit, and close sockets.

    Also removes the UNIX socket path if one was created. Socket cleanup and
    path removal run even when the graceful shutdown raises.
    """
    # First, stop accepting new connections
    self._stop = True
    try:
        self.serversock.close()
    except Exception:
        pass

    # Then, gracefully end existing sessions
    try:
        self.graceful_shutdown(wait=max(2.0, 10 * self.latency))
    finally:
        with self._lock:
            for st in list(self.clients.values()):
                if st is None:
                    continue
                self._unregister_sock(st.sock)
                try:
                    st.sock.close()
                except Exception:
                    pass
            try:
                self._selector.close()
            except Exception:
                pass

        # if unix socket, remove the path. Unlink directly instead of
        # checking for existence first, so a concurrent removal cannot make
        # this raise.
        path = self.unixsocket_path
        if path:
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass
            else:
                print(f"[SocketHub] Unlinked unix socket path {path}")