# -*- coding: utf-8 -*-
""" Socket supply cog for DeadBeat """
from __future__ import absolute_import, division, print_function, unicode_literals
import socket
import select
import os
import os.path as pth
import sys
import re
import errno
import logging
from . import conf
from .movement import Cog, itemgetter, itemsetter, uuidsetter, str_type
__all__ = ["SocketSupply"]
#: Configuration insert for the socket listener: a single "path" item whose
#: default is the script path (``sys.argv[0]``) with its extension replaced
#: by ".socket".
socket_supply_base_config = (
    conf.cfg_item(
        "path",
        str_type,
        "Path of listening socket",
        default=pth.splitext(sys.argv[0])[0] + ".socket"
    ),
)
class SocketSupply(Cog):
    """ Unix socket listener, accepting character separated records.

    Listens on a Unix stream socket, accepts any number of client
    connections and turns every separator terminated record received from
    them into a new piece of event data.

    Received bytes are buffered per connection; only complete records
    (those followed by a separator) are emitted while the connection is
    open. A trailing record without a separator is emitted once the
    connection is closed.
    """

    def __init__(self, esc, path, eols="\0\n", factory=dict, data_set=None, id_get=None, id_set=None):
        """ Initialize SocketSupply, bind the listening socket and register it for polling.

        Any existing filesystem entry at `path` is removed before binding.

        :param esc: `escapement` singleton.
        :param path: Unix socket path to listen to.
        :param eols: List of characters understood as record separators.
        :param factory: Callable returning new event data.
        :param data_set: Setter for inserting new lines into data (defaults to ``itemsetter("input")``).
        :param id_get: Getter of ID (defaults to ``itemgetter("ID")``).
        :param id_set: Generator and setter of ID (defaults to ``uuidsetter("ID")``).
        """
        self.esc = esc
        self.path = path
        self.eols = eols
        self.factory = factory
        self.data_set = data_set or itemsetter("input")
        self.id_get = id_get or itemgetter("ID")
        self.id_set = id_set or uuidsetter("ID")
        # Not used by this class; kept so the set of instance attributes is unchanged.
        self.buf = b""
        # Text form of the separator pattern; kept for attribute compatibility.
        self.eol_re = re.compile('|'.join(re.escape(eol) for eol in eols))
        # Bytes form, used for splitting the raw stream. Splitting before
        # decoding means a multi-byte UTF-8 character that straddles two
        # recv() chunks is never decoded in halves.
        self._eol_bytes_re = re.compile(
            b'|'.join(re.escape(eol.encode("utf-8")) for eol in eols))
        self.conns = {}     # fileno -> (connection, address)
        self.bufs = {}      # fileno -> bytes of the not yet terminated record
        self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        # Best effort removal of a stale socket file; if the path cannot be
        # freed, bind() below reports the failure.
        try:
            os.remove(self.path)
        except OSError:
            pass
        self.sock.bind(self.path)
        self.sock.setblocking(0)
        self.sock.listen(1)
        self.esc.fd_register(self.sock, select.EPOLLIN, self._accept)

    def _accept(self, fd, poll_event):
        """ Accept a pending connection and register its read and hangup callbacks.

        :param fd: Descriptor reported by the poller (unused, the listening socket is used directly).
        :param poll_event: Poll event mask (unused).
        """
        conn, addr = self.sock.accept()
        conn.setblocking(0)
        self.conns[conn.fileno()] = (conn, addr)
        logging.debug("Accepted connection %i", conn.fileno())
        self.esc.fd_register(conn, select.EPOLLIN|select.EPOLLPRI, self._handle)
        self.esc.fd_register(conn, select.EPOLLERR|select.EPOLLHUP, self._hup)

    def _make_data(self, line):
        """ Create new event data, assign an ID to it and store `line` in it.

        :param line: Text of one received record.
        :returns: The data as returned by the `data_set` setter.
        """
        data = self.factory()
        data = self.id_set(data)
        data = self.data_set(data, line)
        return data

    def _handle(self, fd, poll_event):
        """ Read everything available on a connection and emit the complete records.

        :param fd: File number of the readable connection (key into `self.conns`).
        :param poll_event: Poll event mask, passed on to `_hup` on end of stream.
        :returns: List of event data, one per complete record (possibly
            empty), or None when `fd` is not a known connection.
        :raises IOError: When `recv` fails for a reason other than "no data available".
        :raises UnicodeDecodeError: When a record is not valid UTF-8.
        """
        try:
            conn, addr = self.conns[fd]
        except KeyError:
            return None
        chunks = [self.bufs.get(fd, b"")]
        eof = False
        # Drain the non-blocking socket until it would block or the peer is done.
        while True:
            try:
                new_data = conn.recv(4096)
            except IOError as e:
                if e.errno in (errno.EAGAIN, errno.EWOULDBLOCK):
                    break
                raise
            if not new_data:
                eof = True
                break
            chunks.append(new_data)
        lines = self._eol_bytes_re.split(b"".join(chunks))
        # split() always yields at least one item and the last one is the
        # text after the final separator, i.e. a record which is not
        # complete yet. It always goes back into the buffer and is never
        # emitted here. It is stored before decoding so that a malformed
        # record is not processed again on the next call.
        self.bufs[fd] = lines.pop()
        events = [self._make_data(line.decode("utf-8")) for line in lines]
        if eof:
            # The peer finished sending. Tear the connection down through the
            # same path as a hangup (which also flushes the remainder of the
            # buffer), otherwise the socket would stay readable forever.
            events.extend(self._hup(fd, poll_event) or ())
        return events

    def _hup(self, fd, poll_event):
        """ Close a connection and emit whatever unterminated record was left in its buffer.

        :param fd: File number of the connection (key into `self.conns`).
        :param poll_event: Poll event mask (unused).
        :returns: One-item tuple with event data when the buffer was not
            empty, None otherwise (also when `fd` is not a known connection).
        :raises UnicodeDecodeError: When the remaining data is not valid UTF-8.
        """
        try:
            conn, addr = self.conns.pop(fd)
        except KeyError:
            return None
        conn.close()
        logging.debug("Closed connection %i", fd)
        buf = self.bufs.pop(fd, None)
        if buf:
            return (self._make_data(buf.decode("utf-8")),)
        return None