# Source code for autopush.websocket

"""Websocket Protocol handler and HTTP Endpoints for Connection Node

Private HTTP Endpoints
======================

These HTTP endpoints are only for communication from endpoint nodes and must
not be publicly exposed.

.. http:put:: /push/(uuid:uaid)

    Send a notification to a connected client with the given `uaid`.

    :statuscode 200: Client is connected and delivery will be attempted.
    :statuscode 404: Client is not connected to this node.
    :statuscode 503: Client is connected, but currently busy.

.. http:put:: /notif/(uuid:uaid)

    Trigger a stored notification check for a connected client.

    :statuscode 200: Client is connected, and has started checking.
    :statuscode 202: Client is connected but busy, will check notifications
                     when not busy.
    :statuscode 404: Client is not connected to this node.

.. http:delete:: /notif/(uuid:uaid)/(int:connected_at)

    Immediately drop a client of this `uaid` if its connection time matches the
    `connected_at` provided.

"""
import json
import time
import uuid
from collections import defaultdict
from functools import partial, wraps
from random import randrange

import attr
from attr import (
    Factory,
    attrs,
    attrib
)
from autobahn.twisted.resource import WebSocketResource
from autobahn.twisted.websocket import (
    WebSocketServerFactory,
    WebSocketServerProtocol
)
from autobahn.websocket.protocol import ConnectionRequest  # noqa
from botocore.exceptions import ClientError
from botocore.vendored.requests.packages import urllib3
from twisted.internet import reactor
from twisted.internet.defer import (
    Deferred,
    DeferredList,
    CancelledError
)
from twisted.internet.error import (
    ConnectError,
    ConnectionClosed,
    DNSLookupError)
from twisted.internet.interfaces import IProducer
from twisted.internet.threads import deferToThread
from twisted.logger import Logger
from twisted.protocols import policies
from twisted.python import failure
from twisted.web._newclient import ResponseFailed
from twisted.web.client import Agent  # noqa
from twisted.web.resource import Resource
from twisted.web.server import Site
from typing import (  # noqa
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
)
from zope.interface import implementer

from autopush import __version__
from autopush.base import BaseHandler
from autopush.config import AutopushConfig  # noqa
from autopush.db import (
    has_connected_this_month,
    hasher,
    generate_last_connect,
)
from autopush.db import DatabaseManager, Message  # noqa
from autopush.exceptions import MessageOverloadException, ItemNotFound
from autopush.noseplugin import track_object
from autopush.protocol import IgnoreBody
from autopush.metrics import IMetrics, make_tags  # noqa
from autopush.ssl import AutopushSSLContextFactory  # noqa
from autopush.utils import (
    parse_user_agent,
    validate_uaid,
    WebPushNotification,
    ms_time
)


# Schema version of the router-table user record; records with an older
# (or missing) record_version are flagged for reset (see _verify_user_record).
USER_RECORD_VERSION = 1
# Documentation URL sent back to clients in the "more_info" field of
# websocket error replies.
DEFAULT_WS_ERR = "http://autopush.readthedocs.io/en/" \
                 "latest/api/websocket.html#private-http-endpoint"
# codes expected from the client (and emitted as a metric tag)
NACK_CODES = range(301, 304)


def extract_code(data):
    """Extracts and converts a code key if found in data dict

    Returns the integer ``code`` value, or 0 when it is absent, falsy,
    or not an integer.
    """
    value = data.get("code", None)
    return value if value and isinstance(value, int) else 0


def log_exception(func):
    """Exception Logger Decorator for protocol methods

    Wraps a protocol method so that unexpected exceptions are routed to
    ``self.log_failure`` when ``self._log_exc`` is set; otherwise the
    exception propagates (useful under test).
    """
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except Exception:
            if self._log_exc:
                self.log_failure(failure.Failure())
            else:
                raise
    return wrapper
@attrs(slots=True)
class SessionStatistics(object):
    """Websocket Session Statistics

    Tracks statistics about the session that are logged when the websocket
    session has been closed.

    """
    # User data
    uaid_hash = attrib(default="")  # type: str
    uaid_reset = attrib(default=False)  # type: bool
    existing_uaid = attrib(default=False)  # type: bool
    connection_type = attrib(default="")  # type: str
    host = attrib(default="")  # type: str
    ua_os_family = attrib(default="")  # type: str
    ua_os_ver = attrib(default="")  # type: str
    ua_browser_family = attrib(default="")  # type: str
    ua_browser_ver = attrib(default="")  # type: str
    connection_time = attrib(default=0)  # type: int

    # Usage data
    direct_acked = attrib(default=0)  # type: int
    direct_storage = attrib(default=0)  # type: int
    stored_retrieved = attrib(default=0)  # type: int
    stored_acked = attrib(default=0)  # type: int
    nacks = attrib(default=0)  # type: int
    unregisters = attrib(default=0)  # type: int
    registers = attrib(default=0)  # type: int

    def logging_data(self):
        # type: () -> Dict[str, Any]
        """Return the stats as a plain dict for structured logging"""
        return attr.asdict(self)


@implementer(IProducer)
@attrs(slots=True)
class PushState(object):
    """Compact storage of a PushProtocolConnection's state"""
    db = attrib()  # type: DatabaseManager
    _callbacks = attrib(default=Factory(list))  # type: List[Deferred]
    stats = attrib(
        default=Factory(SessionStatistics))  # type: SessionStatistics
    _user_agent = attrib(default=None)  # type: Optional[str]
    _base_tags = attrib(default=Factory(list))  # type: List[str]
    raw_agent = attrib(default=Factory(dict))  # type: Optional[Dict[str, str]]

    _should_stop = attrib(default=False)  # type: bool
    _paused = attrib(default=False)  # type: bool

    _uaid_obj = attrib(default=None)  # type: Optional[uuid.UUID]
    _uaid_hash = attrib(default=None)  # type: Optional[str]

    last_ping = attrib(default=0.0)  # type: float
    check_storage = attrib(default=False)  # type: bool
    router_type = attrib(default=None)  # type: Optional[str]
    connected_at = attrib(default=Factory(ms_time))  # type: float
    ping_time_out = attrib(default=False)  # type: bool

    # Message table rotation
    message_month = attrib(init=False)  # type: str
    rotate_message_table = attrib(default=False)  # type: bool

    _check_notifications = attrib(default=False)  # type: bool
    _more_notifications = attrib(default=False)  # type: bool

    # Timestamped message handling defaults
    scan_timestamps = attrib(default=False)  # type: bool
    current_timestamp = attrib(default=None)  # type: Optional[int]

    # Hanger for common actions we defer
    _notification_fetch = attrib(default=None)  # type: Optional[Deferred]
    _register = attrib(default=None)  # type: Optional[Deferred]

    # Reflects Notification's sent that haven't been ack'd This is
    # simplepush style by default
    updates_sent = attrib(default=Factory(dict))  # type: Dict

    # Track Notification's we don't need to delete separately This is
    # simplepush style by default
    direct_updates = attrib(default=Factory(dict))  # type: Dict

    # Whether this record should be reset after delivering stored
    # messages
    _reset_uaid = attrib(default=False)  # type: bool

    @classmethod
    def from_request(cls, request, **kwargs):
        # type: (ConnectionRequest, **Any) -> PushState
        """Alternate constructor pulling agent/host info off the request"""
        return cls(
            user_agent=request.headers.get("user-agent"),
            stats=SessionStatistics(host=request.host),
            **kwargs
        )

    def __attrs_post_init__(self):
        """Initialize PushState"""
        if self._user_agent:
            dd_tags, self.raw_agent = parse_user_agent(self._user_agent)
            for tag_name, tag_value in dd_tags.items():
                setattr(self.stats, tag_name, tag_value)
                self._base_tags.append("%s:%s" % (tag_name, tag_value))
            self.stats.ua_os_ver = self.raw_agent["ua_os_ver"]
            self.stats.ua_browser_ver = self.raw_agent["ua_browser_ver"]
        if self.stats.host:
            self._base_tags.append("host:%s" % self.stats.host)

        # Message table rotation initial settings
        self.message_month = self.db.current_msg_month

        self.reset_uaid = False

    @property
    def user_agent(self):
        # type: () -> str
        return self._user_agent or "None"

    @property
    def reset_uaid(self):
        # type: () -> bool
        return self._reset_uaid

    @reset_uaid.setter
    def reset_uaid(self, value):
        # Setting True also records the reset in the session stats
        if value:
            self._reset_uaid = True
            self.stats.uaid_reset = True
        else:
            self._reset_uaid = False

    @property
    def uaid_obj(self):
        # type: () -> Optional[uuid.UUID]
        return self._uaid_obj

    @property
    def uaid_hash(self):
        # type: () -> str
        return self._uaid_hash

    @property
    def uaid(self):
        # type: () -> Optional[str]
        return self._uaid_obj.hex if self._uaid_obj else None

    @uaid.setter
    def uaid(self, value):
        self._uaid_obj = uuid.UUID(value) if value else None
        self._uaid_hash = hasher(value) if value else ""
        self.stats.uaid_hash = self._uaid_hash

    def init_connection(self):
        """Set the connection type for the client"""
        self._base_tags.append("use_webpush:True")
        self.router_type = self.stats.connection_type = "webpush"

        # Update our message tracking for webpush
        self.updates_sent = defaultdict(lambda: [])
        self.direct_updates = defaultdict(lambda: [])

    def pauseProducing(self):
        """IProducer implementation tracking if we should pause output"""
        self._paused = True

    def resumeProducing(self):
        """IProducer implementation tracking when we should resume output"""
        self._paused = False

    def stopProducing(self):
        """IProducer implementation tracking when we should stop"""
        self._paused = True
        self._should_stop = True
class PushServerProtocol(WebSocketServerProtocol, policies.TimeoutMixin):
    """Main Websocket Connection Protocol"""
    log = Logger()

    # Testing purposes: these are class attributes so tests can patch them
    parent_class = WebSocketServerProtocol
    randrange = randrange
    _log_exc = True
    sent_notification_count = 0

    @property
    def conf(self):
        # type: () -> AutopushConfig
        """Shortcut to the factory's AutopushConfig"""
        return self.factory.conf

    @property
    def db(self):
        # type: () -> DatabaseManager
        """Shortcut to the factory's DatabaseManager"""
        return self.factory.db

    @property
    def metrics(self):
        # type: () -> IMetrics
        """Shortcut to the database manager's metrics sink"""
        return self.db.metrics
def deferToThread(self, func, *args, **kwargs):
    # type: (Callable[..., Any], *Any, **Any) -> Deferred
    """deferToThread helper that tracks defers outstanding

    The deferred is registered in ``self.ps._callbacks`` so cleanUp can
    cancel it if the connection drops, and removes itself on completion.
    """
    d = deferToThread(func, *args, **kwargs)
    self.ps._callbacks.append(d)

    def _untrack(result):
        # Drop our bookkeeping entry; pass the result through untouched
        if d in self.ps._callbacks:
            self.ps._callbacks.remove(d)
        return result
    d.addBoth(_untrack)
    return d
def deferToLater(self, when, func, *args, **kwargs):
    # type: (float, Callable[..., Any], *Any, **Any) -> Deferred
    """deferToLater helper that tracks defers outstanding"""
    def cancel(d):
        d._cancelled = True

    d = Deferred(canceller=cancel)
    d._cancelled = False
    self.ps._callbacks.append(d)

    def _run():
        if d in self.ps._callbacks:
            self.ps._callbacks.remove(d)

        # Don't run if the deferred was cancelled already
        if d._cancelled:
            return
        try:
            d.callback(func(*args, **kwargs))
        except Exception:
            d.errback(failure.Failure())

    reactor.callLater(when, _run)
    return d
def trap_cancel(self, fail):
    """errBack that swallows deferred cancellation"""
    fail.trap(CancelledError)

def trap_connection_err(self, fail):
    """errBack that swallows expected network/connection failures"""
    fail.trap(ConnectError, ConnectionClosed, ResponseFailed,
              DNSLookupError)

def trap_boto3_err(self, fail):
    # trap boto3 ConnectTimeoutError in retry
    fail.trap(urllib3.exceptions.ConnectTimeoutError)
def force_retry(self, func, *args, **kwargs):
    # type: (Callable[..., Any], *Any, **Any) -> Deferred
    """Forcefully retry a function in a thread until it doesn't error

    Note that this does not use ``self.deferToThread``, so this will
    continue to retry even if the client drops.
    """
    def wrapper(result, *w_args, **w_kwargs):
        if isinstance(result, failure.Failure):
            # This is an exception, log it
            self.log_failure(result)

        retry = deferToThread(func, *args, **kwargs)
        retry.addErrback(wrapper)
        return retry

    d = deferToThread(func, *args, **kwargs)
    d.addErrback(wrapper)
    return d
@property
def base_tags(self):
    """Property that uses None if there's no tags due to a DataDog
    library bug"""
    return self.ps._base_tags or None
def log_failure(self, failure, **kwargs):
    """Log a twisted failure out through twisted's log.failure"""
    self.log.failure(format="Unexpected error", failure=failure, **kwargs)
@property
def paused(self):
    """Indicates if we are paused for output production or not"""
    return self.ps._paused
@log_exception
def _sendAutoPing(self):
    """Override for sanity checking during auto-ping interval"""
    # Note: it's possible (but tracking information has yet to prove) that
    # a websocket connection could persist longer than the message record
    # expiration time (~30d), which might cause some problems. Most
    # websocket connections time out far, far earlier than that, which
    # resets the record expiration times.
    if not self.ps.uaid:
        # No uaid yet, drop the connection
        self.sendClose()
    elif self.factory.clients.get(self.ps.uaid) != self:
        # UAID, but we're not in clients anymore for some reason
        self.sendClose()
    return WebSocketServerProtocol._sendAutoPing(self)
@log_exception
def sendClose(self, code=None, reason=None):
    """Override to add tracker that ensures the connection is truly
    torn down"""
    # If onClose hasn't fired within the handshake timeout plus slack,
    # nukeConnection aborts the transport outright.
    reactor.callLater(10 + self.closeHandshakeTimeout, self.nukeConnection)
    return WebSocketServerProtocol.sendClose(self, code, reason)
@log_exception
def nukeConnection(self):
    """Aggressive connection shutdown using abortConnection if onClose
    still hadn't run by this point"""
    # Did onClose get called? If so, we shutdown properly, no worries.
    if hasattr(self, "_shutdown_ran"):
        return
    self.transport.abortConnection()
@log_exception
def onConnect(self, request):
    """autobahn onConnect handler for when a connection has started"""
    track_object(self, msg="onConnect Start")
    self.ps = PushState.from_request(request=request, db=self.db)

    # Setup ourself to handle producing the data
    self.transport.bufferSize = 2 * 1024
    try:
        self.transport.registerProducer(self.ps, True)
    except RuntimeError:
        # HACK: Autobahn/twisted/h2 hacks mess this up, ensure we can
        # register the producer
        self.transport.unregisterProducer()
        self.transport.registerProducer(self.ps, True)

    if self.conf.hello_timeout > 0:
        self.setTimeout(self.conf.hello_timeout)
############################################################# # Connection Methods #############################################################
@log_exception
def processHandshake(self):
    """Disable host port checking on nonstandard ports since some
    clients are buggy and don't provide it"""
    track_object(self, msg="processHandshake")
    port = self.conf.port
    hide = port != 80 and port != 443
    old_port = self.factory.externalPort
    try:
        if hide:
            # Temporarily blank externalPort so autobahn skips the check
            self.factory.externalPort = None
        return self.parent_class.processHandshake(self)
    except UnicodeEncodeError:
        self.failHandshake("Error reading handshake data")
    finally:
        if hide:
            self.factory.externalPort = old_port
@log_exception
def onMessage(self, payload, isBinary):
    """autobahn onMessage processor for incoming messages"""
    if isBinary:
        self.sendClose()
        return

    track_object(self, msg="onMessage")
    data = None
    try:
        data = json.loads(payload.decode('utf8'))
    except (TypeError, ValueError):
        pass

    if not isinstance(data, dict):
        self.sendClose()
        return

    # Without a UAID, hello must be next
    if not self.ps.uaid:
        return self.process_hello(data)

    # Ping's get a ping reply
    if data == {}:
        return self.process_ping()

    cmd = data.get("messageType")
    # We're no longer idle, prevent early connection closure.
    self.resetTimeout()
    try:
        handler = {
            "hello": self.process_hello,
            "register": self.process_register,
            "unregister": self.process_unregister,
            "ack": self.process_ack,
            "nack": self.process_nack,
        }.get(cmd)
        if handler is None:
            # Unknown command: drop the client
            self.sendClose()
        else:
            return handler(data)
    finally:
        # Done processing, start idle.
        self.resetTimeout()
def timeoutConnection(self):
    """Idle timer fired; close the connection."""
    self.sendClose()
def onAutoPingTimeout(self):
    """Override to track that this shut-down is from a ping timeout"""
    self.ps.ping_time_out = True
    WebSocketServerProtocol.onAutoPingTimeout(self)
@log_exception
def onClose(self, wasClean, code, reason):
    """autobahn onClose handler for shutting down the connection and any
    outstanding deferreds related to this connection"""
    try:
        uaid = self.ps.uaid
        self._shutdown_ran = True
        self.ps._should_stop = True
        self.ps._check_notifications = False
    except AttributeError:  # pragma: nocover
        # Sometimes in odd production cases, onClose will be called without
        # onConnect being called to set this up.
        uaid = None

    # Log out the disconnect reason
    if uaid:
        self.cleanUp(wasClean, code, reason)
def cleanUp(self, wasClean, code, reason):
    """Thorough clean-up method to cancel all remaining deferreds, and
    send connection metrics in

    :param wasClean: whether the websocket closed cleanly (unused here)
    :param code: close code (unused here)
    :param reason: close reason (unused here)
    """
    elapsed = (ms_time() - self.ps.connected_at) / 1000.0
    self.metrics.timing("ua.connection.lifespan", duration=elapsed,
                        tags=self.base_tags)
    self.ps.stats.connection_time = int(elapsed)

    # Cleanup our client entry
    if self.ps.uaid and self.factory.clients.get(self.ps.uaid) == self:
        del self.factory.clients[self.ps.uaid]

    # Cancel any outstanding deferreds that weren't already called
    for d in self.ps._callbacks:
        if not d.called:
            d.cancel()

    # Attempt to deliver any notifications not originating from storage
    if self.ps.direct_updates:
        defers = []
        for notifs in self.ps.direct_updates.values():
            # BUGFIX: materialize the filter as a list. On Python 3,
            # filter() returns a lazy iterator: len() below would raise
            # TypeError and iterating twice would yield nothing.
            notifs = [notif for notif in notifs if notif.ttl != 0]
            self.ps.stats.direct_storage += len(notifs)
            defers.extend(self._save_webpush_notif(notif)
                          for notif in notifs)

        # Tag on the notifier once everything has been stored
        dl = DeferredList(defers)
        dl.addBoth(self._lookup_node)

    # Delete and remove remaining dicts and lists
    del self.ps.direct_updates
    del self.ps.updates_sent

    # Log out sessions stats
    self.log.info("Session", **self.ps.stats.logging_data())
def _save_webpush_notif(self, notif):
    """Save a direct_update webpush style notification"""
    message = self.db.message_table(self.ps.message_month)
    return deferToThread(
        message.store_message,
        notif,
    ).addErrback(self.log_failure)
def _lookup_node(self, results):
    """Looks up the node to send a notify for it to check storage if
    connected"""
    # Locate the node that has this client connected
    d = deferToThread(self.db.router.get_uaid, self.ps.uaid)
    d.addCallback(self._notify_node)
    d.addErrback(self._trap_uaid_not_found)
    d.addErrback(self.log_failure,
                 extra="Failed to get UAID for redeliver")
def _trap_uaid_not_found(self, fail):
    # type: (failure.Failure) -> None
    """Traps UAID not found error"""
    fail.trap(ItemNotFound)
def _notify_node(self, result):
    """Checks the result of lookup node to send the notify if the client is
    connected elsewhere now"""
    if not result:
        return
    node_id = result.get("node_id")
    if not node_id:
        return

    # If it's ourselves, we can stop
    if result.get("connected_at") == self.ps.connected_at:
        return

    # Send the notify to the node
    url = node_id + "/notif/" + self.ps.uaid
    d = self.factory.agent.request(
        "PUT",
        url.encode("utf8"),
    ).addCallback(IgnoreBody.ignore)
    d.addErrback(self.trap_connection_err)
    d.addErrback(self.trap_boto3_err)
    d.addErrback(self.log_failure, extra="Failed to notify node")
def returnError(self, messageType, reason, statusCode, close=True,
                url=DEFAULT_WS_ERR):
    """Return an error to a client, and optionally shut down the
    connection safely"""
    send = {"messageType": messageType,
            "reason": reason,
            "status": statusCode}
    if url:
        send["more_info"] = url
    self.sendJSON(send)
    if close:
        self.sendClose()
def error_overload(self, failure, message_type, disconnect=True):
    """Handle database overloads and errors

    If ``disconnect`` is False, an overload error is returned and the
    client is not disconnected.

    Otherwise, pause producing to cease incoming notifications while
    we wait a random interval up to 8 seconds before closing down the
    connection. Most clients wait up to 10 seconds for a command,
    but this is not a guarantee, so rather than never reply, we still
    shut the connection down.

    :param disconnect: Whether the client should be disconnected or not.

    """
    failure.trap(ClientError)

    if disconnect:
        self.transport.pauseProducing()
        d = self.deferToLater(self.randrange(4, 9),
                              self.error_finish_overload, message_type)
        d.addErrback(self.trap_cancel)
    else:
        send = {"messageType": "error",
                "reason": "overloaded",
                "status": 503}
        self.sendJSON(send)
def error_finish_overload(self, message_type):
    """Close the connection down and resume consuming input after the
    random interval from a db overload"""
    # Resume producing so we can finish the shutdown
    self.transport.resumeProducing()
    self.returnError(message_type, "error - overloaded", 503)
def sendJSON(self, body):
    """Send a Python dict as a JSON string in a websocket message"""
    self.sendMessage(json.dumps(body).encode('utf8'), False)
############################################################# # Message Processing Methods #############################################################
def process_hello(self, data):
    """Process a hello message"""
    # This must be a hello, or we kick the client
    cmd = data.get("messageType")
    if cmd != "hello":
        return self.sendClose()

    if self.ps.uaid:
        return self.returnError("hello", "duplicate hello", 401)

    if not data.get("use_webpush", False):
        return self.returnError("hello", "Simplepush not supported", 401)

    self.ps.init_connection()

    uaid = data.get("uaid")
    existing_user, uaid = validate_uaid(uaid)
    self.ps.uaid = uaid
    self.ps.stats.existing_uaid = existing_user
    self.transport.pauseProducing()

    d = self.deferToThread(self._register_user, existing_user)
    d.addCallback(self._check_other_nodes)
    d.addErrback(self.trap_cancel)
    d.addErrback(self.error_overload, "hello")
    d.addErrback(self.error_hello)
    self.ps._register = d
    return d
def _register_user(self, existing_user=True):
    """Register a returning or new user

    :type existing_user: bool

    """
    # If it's an existing user, verify the record is valid
    user_item = None
    if existing_user:
        user_item = self._verify_user_record()

    if not user_item:
        # No valid user record, consider this a new user
        self.ps.uaid = uuid.uuid4().hex
        self.ps.stats.uaid_reset = True
        user_item = dict(
            uaid=self.ps.uaid,
            node_id=self.conf.router_url,
            connected_at=self.ps.connected_at,
            router_type=self.ps.router_type,
            last_connect=generate_last_connect(),
            record_version=USER_RECORD_VERSION,
        )
        user_item["current_month"] = self.ps.message_month

    return self.db.router.register_user(user_item)
def _verify_user_record(self):
    """Verify a user record is valid

    Returns a record that is ready for registering in the database if
    the user record was found.

    :rtype: :class:`~boto.dynamodb2.items.Item` or None

    """
    try:
        record = self.db.router.get_uaid(self.ps.uaid)
    except ItemNotFound:
        return None

    # All records must have a router_type and connected_at, in some odd
    # cases a record exists for some users that doesn't
    if "router_type" not in record or "connected_at" not in record:
        self.log.debug(format="Dropping User", code=104,
                       uaid_hash=self.ps.uaid_hash,
                       uaid_record=repr(record))
        tags = ['code:104']
        self.metrics.increment("ua.expiration", tags=tags)
        self.force_retry(self.db.router.drop_user, self.ps.uaid)
        return None

    # Validate webpush records
    # Current month must exist and be a valid prior month
    if ("current_month" not in record) or record["current_month"] \
            not in self.db.message_tables:
        self.log.debug(format="Dropping User", code=105,
                       uaid_hash=self.ps.uaid_hash,
                       uaid_record=repr(record))
        self.force_retry(self.db.router.drop_user, self.ps.uaid)
        tags = ['code:105']
        self.metrics.increment("ua.expiration", tags=tags)
        return None

    # Determine if message table rotation is needed
    if record["current_month"] != self.ps.message_month:
        self.ps.message_month = record["current_month"]
        self.ps.rotate_message_table = True

    # Include and update last_connect if needed, otherwise exclude
    if has_connected_this_month(record):
        del record["last_connect"]
    else:
        record["last_connect"] = generate_last_connect()

    # Determine if this is missing a record version
    if ("record_version" not in record or
            int(record["record_version"]) < USER_RECORD_VERSION):
        self.ps.reset_uaid = True

    # Update the node_id, connected_at for this node/connected_at
    record["node_id"] = self.conf.router_url
    record["connected_at"] = self.ps.connected_at
    return record
def error_hello(self, failure):
    """errBack for hello failures"""
    self.transport.resumeProducing()
    self.log_failure(failure)
    self.returnError("hello", "error", 503)
def _check_other_nodes(self, result, url=DEFAULT_WS_ERR):
    """callback to check other nodes for clients and send them a delete as
    needed"""
    self.transport.resumeProducing()

    registered, previous = result
    if not registered:
        # Registration failed
        msg = {"messageType": "hello", "reason": "already_connected",
               "status": 500, "more_info": url}
        self.sendJSON(msg)
        return

    # Handle dupes on the same node
    existing = self.factory.clients.get(self.ps.uaid)
    if existing:
        if self.ps.connected_at <= existing.ps.connected_at:
            # The other connection is newer; drop ourselves
            self.sendClose()
            return
        else:
            existing.sendClose()

    # TODO: Remove this block, issue #245.
    if previous and "node_id" in previous:
        # Get the previous information returned from dynamodb.
        node_id = previous["node_id"]
        last_connect = previous.get("connected_at")

        if last_connect and node_id != self.conf.router_url:
            url = "%s/notif/%s/%s" % (node_id, self.ps.uaid, last_connect)
            d = self.factory.agent.request("DELETE", url.encode("utf8"))
            d.addErrback(self.trap_connection_err)
            d.addErrback(self.trap_boto3_err)
            d.addErrback(self.log_failure,
                         extra="Failed to delete old node")

    self.finish_hello(previous)
def finish_hello(self, previous):
    """callback for successful hello message, that sends hello reply"""
    self.ps._register = None
    msg = {"messageType": "hello",
           "uaid": self.ps.uaid,
           "status": 200,
           "use_webpush": True}
    if self.autoPingInterval:
        msg["ping"] = self.autoPingInterval
    msg['env'] = self.conf.env
    self.factory.clients[self.ps.uaid] = self
    self.sendJSON(msg)
    self.log.debug(format="hello", uaid_hash=self.ps.uaid_hash,
                   **self.ps.raw_agent)
    self.metrics.increment("ua.command.hello")
    self.process_notifications()
def process_notifications(self):
    """Run a notification check against storage"""
    # Bail immediately if we are closed.
    if self.ps._should_stop:
        return

    # Are we paused? Try again later.
    if self.paused:
        d = self.deferToLater(1, self.process_notifications)
        d.addErrback(self.trap_cancel)
        return

    # Webpush with any outstanding storage-based must all be cleared
    if any(self.ps.updates_sent.values()):
        d = self.deferToLater(1, self.process_notifications)
        d.addErrback(self.trap_cancel)
        return

    # Are we already running?
    if self.ps._notification_fetch:
        # Cancel the prior, last one wins
        self.ps._notification_fetch.cancel()

    self.ps._check_notifications = False
    self.ps._more_notifications = True

    d = self.deferToThread(self.webpush_fetch())
    d.addCallback(self.finish_notifications)
    d.addErrback(self.error_notification_overload)
    d.addErrback(self.trap_cancel)
    d.addErrback(self.error_message_overload)
    # The following errback closes the connection. It must be the last
    # errback in the chain.
    d.addErrback(self.error_notifications)
    self.ps._notification_fetch = d
def webpush_fetch(self):
    """Helper to return an appropriate function to fetch messages"""
    message = self.db.message_table(self.ps.message_month)
    if self.ps.scan_timestamps:
        return partial(message.fetch_timestamp_messages,
                       self.ps.uaid_obj,
                       self.ps.current_timestamp)
    return partial(message.fetch_messages, self.ps.uaid_obj)
def error_notifications(self, fail):
    """errBack for notification check failing"""
    # If we error'd out on this important check, we drop the connection
    self.log_failure(fail)
    self.sendClose()
def error_notification_overload(self, fail):
    """errBack for provisioned errors during notification check"""
    fail.trap(ClientError)
    if (fail.value.response["Error"]["Code"] !=
            "ProvisionedThroughputExceededException"):
        return fail  # pragma nocover
    # Silently ignore the error, and reschedule the notification check
    # to run up to a minute in the future to distribute load farther
    # out
    d = self.deferToLater(randrange(5, 60), self.process_notifications)
    d.addErrback(self.trap_cancel)
def error_message_overload(self, fail):
    """errBack for handling excessive messages per UAID"""
    fail.trap(MessageOverloadException)
    self.force_retry(self.db.router.drop_user, self.ps.uaid)
    self.sendClose()
def finish_notifications(self, notifs):
    """callback for processing notifications from storage"""
    self.ps._notification_fetch = None

    # Are we paused, try again later
    if self.paused:
        d = self.deferToLater(1, self.process_notifications)
        d.addErrback(self.trap_cancel)
        return

    # Process notifications differently based on webpush style or not
    return self.finish_webpush_notifications(notifs)
def finish_webpush_notifications(self, result):
    # type: (Tuple[str, List[WebPushNotification]]) -> None
    """WebPush notification processor"""
    timestamp, notifs = result

    # If there's a timestamp, update our current one to it
    if timestamp:
        self.ps.current_timestamp = timestamp

    if not notifs:
        # No more notifications, check timestamped?
        if not self.ps.scan_timestamps:
            # Scan for timestamped then
            self.ps.scan_timestamps = True
            d = self.deferToLater(0, self.process_notifications)
            d.addErrback(self.trap_cancel)
            return

        # No more notifications, and we've scanned timestamped.
        self.ps._more_notifications = False
        self.ps.scan_timestamps = False
        self.sent_notification_count = 0
        if self.ps._check_notifications:
            # Told to check again, start over
            self.ps._check_notifications = False
            d = self.deferToLater(1, self.process_notifications)
            d.addErrback(self.trap_cancel)
            return

        # Told to reset the user?
        if self.ps.reset_uaid:
            self.force_retry(self.db.router.drop_user, self.ps.uaid)
            self.sendClose()

        # Not told to check for notifications, do we need to now rotate
        # the message table?
        if self.ps.rotate_message_table:
            self._rotate_message_table()
        return

    # Send out all the notifications
    now = int(time.time())
    messages_sent = False
    message = self.db.message_table(self.ps.message_month)
    for notif in notifs:
        self.ps.stats.stored_retrieved += 1
        # If the TTL is too old, don't deliver and fire a delete off
        if notif.expired(at_time=now):
            if not notif.sortkey_timestamp:
                # Delete non-timestamped messages
                self.force_retry(message.delete_message, notif)
            # nocover here as coverage gets confused on the line below
            # for unknown reasons
            continue  # pragma: nocover

        self.ps.updates_sent[str(notif.channel_id)].append(notif)
        msg = notif.websocket_format()
        messages_sent = True
        self.sent_notification_count += 1
        if self.sent_notification_count > self.conf.msg_limit:
            raise MessageOverloadException()
        self.emit_send_metrics(notif)
        self.sendJSON(msg)

    # Did we send any messages?
    if messages_sent:
        return

    # No messages sent, update the record if needed
    if self.ps.current_timestamp:
        self.force_retry(
            message.update_last_message_read,
            self.ps.uaid_obj,
            self.ps.current_timestamp)

    # Schedule a new process check
    self.check_missed_notifications(None)
def _rotate_message_table(self):
    """Function to fire off a message table copy of channels + update the
    router current_month entry"""
    self.transport.pauseProducing()
    d = self.deferToThread(self._monthly_transition)
    d.addCallback(self._finish_monthly_transition)
    d.addErrback(self.trap_cancel)
    d.addErrback(self.error_monthly_rotation_overload)
    d.addErrback(self.error_notifications)
def _monthly_transition(self):
    """Transition the client to use a new message month

    Utilized to migrate a users channels to a new message month and
    update the router record reflecting the proper month.

    This is a blocking function that does *not* run on the event loop.

    """
    # Get the current channels for this month
    message = self.db.message_table(self.ps.message_month)
    _, channels = message.all_channels(self.ps.uaid)

    # Get the current message month
    cur_month = self.db.current_msg_month
    if channels:
        # Save the current channels into this months message table
        msg_table = self.db.message_table(cur_month)
        msg_table.save_channels(self.ps.uaid, channels)

    # Finally, update the route message month
    self.db.router.update_message_month(self.ps.uaid, cur_month)
def _finish_monthly_transition(self, result):
    """Mark the client as successfully transitioned and resume"""
    # Update the current month now that we've moved forward a month
    self.ps.message_month = self.db.current_msg_month
    self.ps.rotate_message_table = False
    self.transport.resumeProducing()
def error_monthly_rotation_overload(self, fail):
    """Capture overload on monthly table rotation attempt

    If a provision exceeded error hits while attempting monthly table
    rotation, schedule it all over and re-scan the messages. Normal
    websocket client flow is returned in the meantime.

    """
    fail.trap(ClientError)
    if (fail.value.response['Error']['Code'] !=
            "ProvisionedThroughputExceededException"):
        return fail  # pragma nocover
    self.transport.resumeProducing()
    d = self.deferToLater(randrange(1, 30 * 60),
                          self.process_notifications)
    d.addErrback(self.trap_cancel)
def _send_ping(self):
    """Helper for ping sending that tracks when the ping was sent"""
    self.ps.last_ping = time.time()
    return self.sendMessage("{}", False)
def process_ping(self):
    """Ping Handling

    Clients in the wild have a bug that lowers their ping interval to 0. It
    will never increase for them, as there is no way to remedy this without
    causing the client to use drastically more battery/data-usage we send
    them a code 4774 close to signify that they should stop until network
    change.

    No other client should ping more than once per minute, or we tell them
    to go away.

    """
    now = time.time()
    if now - self.ps.last_ping >= 55:
        self._send_ping()
    else:
        # Pinging too frequently: tell the client to back off
        self.sendClose(code=4774)
def process_register(self, data):
    """Process a register message

    Validates the requested channelID then defers endpoint creation to a
    thread; replies via finish_register on success or error_register on
    failure.
    """
    if "channelID" not in data:
        return self.bad_message("register")
    chid = data["channelID"]
    try:
        # Require canonical lower-case, dashed UUID format
        if str(uuid.UUID(chid)) != chid:
            # BUGFIX: the original adjacent string literals concatenated
            # without a space, producing "...use lower" as "uselower".
            return self.bad_message("register", "Bad UUID format, use "
                                    "lower case, dashed format")
    except (ValueError, TypeError):
        return self.bad_message("register", "Invalid UUID specified")
    self.transport.pauseProducing()

    d = self.deferToThread(self.conf.make_endpoint, self.ps.uaid,
                           chid, data.get("key"))
    d.addCallback(self.finish_register, chid)
    d.addErrback(self.trap_cancel)
    d.addErrback(self.error_register)
    return d
def error_register(self, fail):
    """errBack handler for registering to fail"""
    self.transport.resumeProducing()
    msg = {"messageType": "register", "status": 500,
           "reason": "An unexpected server error occurred"}
    self.sendJSON(msg)
    self.log_failure(fail, extra="Failed to register")
def finish_register(self, endpoint, chid):
    """callback for successful endpoint creation, sends register reply"""
    message = self.db.message_table(self.ps.message_month)
    d = self.deferToThread(message.register_channel, self.ps.uaid, chid)
    d.addCallback(self.send_register_finish, endpoint, chid)
    # Note: No trap_cancel needed here since the deferred here is
    # returned to process_register which will trap it
    d.addErrback(self.error_overload, "register", disconnect=False)
    return d
def send_register_finish(self, result, endpoint, chid): self.transport.resumeProducing() msg = {"messageType": "register", "channelID": chid, "pushEndpoint": endpoint, "status": 200} self.sendJSON(msg) self.metrics.increment("ua.command.register") self.ps.stats.registers += 1 self.log.info(format="Register", channel_id=chid, endpoint=endpoint, uaid_hash=self.ps.uaid_hash, user_agent=self.ps.user_agent, **self.ps.raw_agent)
[docs] def process_unregister(self, data): """Process an unregister message""" if "channelID" not in data: return self.bad_message("unregister", "Missing ChannelID") chid = data["channelID"] try: uuid.UUID(chid) except ValueError: return self.bad_message("unregister", "Invalid ChannelID") self.metrics.increment("ua.command.unregister") self.ps.stats.unregisters += 1 event = dict(format="Unregister", channel_id=chid, uaid_hash=self.ps.uaid_hash, user_agent=self.ps.user_agent, **self.ps.raw_agent) if "code" in data: event["code"] = extract_code(data) self.log.info(**event) # Clear out any existing tracked messages for this channel self.ps.direct_updates[chid] = [] self.ps.updates_sent[chid] = [] # Unregister the channel message = self.db.message_table(self.ps.message_month) self.force_retry(message.unregister_channel, self.ps.uaid, chid) data["status"] = 200 self.sendJSON(data)
[docs] def ack_update(self, update): """Helper function for tracking ack'd updates Returns either None, if no delete_notification call is needed, or a deferred for the delete_notification call if it was needed. """ if not update: return chid = update.get("channelID") version = update.get("version") if not chid or not version: return code = extract_code(update) return self._handle_webpush_ack(chid, version, code)
[docs] def _handle_webpush_ack(self, chid, version, code): """Handle clearing out a webpush ack""" def ver_filter(notif): return notif.version == version found = filter( ver_filter, self.ps.direct_updates[chid] ) # type: List[WebPushNotification] if found: msg = found[0] size = len(msg.data) if msg.data else 0 self.log.debug(format="Ack", router_key="webpush", channel_id=chid, message_id=version, message_source="direct", message_size=size, uaid_hash=self.ps.uaid_hash, user_agent=self.ps.user_agent, code=code, **self.ps.raw_agent) self.ps.stats.direct_acked += 1 self.ps.direct_updates[chid].remove(msg) return found = filter( ver_filter, self.ps.updates_sent[chid] ) # type: List[WebPushNotification] if found: msg = found[0] size = len(msg.data) if msg.data else 0 self.log.debug(format="Ack", router_key="webpush", channel_id=chid, message_id=version, message_source="stored", message_size=size, uaid_hash=self.ps.uaid_hash, user_agent=self.ps.user_agent, code=code, **self.ps.raw_agent) self.ps.stats.stored_acked += 1 message = self.db.message_table(self.ps.message_month) if msg.sortkey_timestamp: # Is this the last un-acked message we're waiting for? 
last_unacked = sum( len(sent) for sent in self.ps.updates_sent.itervalues() ) == 1 if (msg.sortkey_timestamp == self.ps.current_timestamp or last_unacked): # If it's the last message in the batch, or last un-acked # message d = self.force_retry( message.update_last_message_read, self.ps.uaid_obj, self.ps.current_timestamp, ) d.addBoth(self._handle_webpush_update_remove, chid, msg) else: # It's timestamped, but not the last of this batch, # so we just remove it from local tracking self._handle_webpush_update_remove(None, chid, msg) d = None else: # No sortkey_timestamp, so legacy/topic message, delete d = self.force_retry(message.delete_message, msg) # We don't remove the update until we know the delete ran # This is because we don't use range queries on dynamodb and # we need to make sure this notification is deleted from the # db before we query it again (to avoid dupes). d.addBoth(self._handle_webpush_update_remove, chid, msg) return d
[docs] def _handle_webpush_update_remove(self, result, chid, notif): """Handle clearing out the updates_sent It's possible the client may leave before this runs, so this is wrapped in a try/except in case the tear-down of self has started. """ try: self.ps.updates_sent[chid].remove(notif) except (AttributeError, ValueError): pass
[docs] def process_ack(self, data): """Process an ack message, delete notifications from storage if needed""" updates = data.get("updates") if not updates or not isinstance(updates, list): return self.metrics.increment("ua.command.ack") defers = filter(None, map(self.ack_update, updates)) if defers: self.transport.pauseProducing() dl = DeferredList(defers) dl.addBoth(self.check_missed_notifications, True) else: self.check_missed_notifications(None)
[docs] def process_nack(self, data): """Process a nack message and log its contents""" code = extract_code(data) version = data.get("version") if not version: return self.log.debug(format="Nack", uaid_hash=self.ps.uaid_hash, user_agent=self.ps.user_agent, message_id=str(version), code=code, **self.ps.raw_agent) mcode = code if code in NACK_CODES else 0 self.metrics.increment( 'ua.command.nack', tags=make_tags(code=mcode)) self.ps.stats.nacks += 1
[docs] def check_missed_notifications(self, results, resume=False): """Check to see if notifications were missed""" if resume: # Resume consuming ack's self.transport.resumeProducing() # Abort if stopped if self.ps._should_stop: return # When using webpush, we don't check again if we have outstanding # notifications if any(self.ps.updates_sent.values()): return # Should we check again? if self.ps._more_notifications: self.process_notifications() elif self.ps._check_notifications: # If we were told to check notifications, start over since we might # have missed a topic message self.ps.scan_timestamps = False self.process_notifications()
[docs] def bad_message(self, typ, message=None, url=DEFAULT_WS_ERR): """Error helper for sending a 401 status back""" msg = {"messageType": typ, "status": 401, "more_info": url} if message: msg["reason"] = message self.sendJSON(msg)
#################################### # Utility function for external use
[docs] def send_notification(self, update): """Utility function for external use This function is called by the HTTP handler to deliver an incoming update notification from an endpoint. """ chid = update["channelID"] # Create the notification notif = WebPushNotification.from_serialized(self.ps.uaid_obj, update) self.ps.direct_updates[chid].append(notif) self.emit_send_metrics(notif) self.sendJSON(notif.websocket_format())
def emit_send_metrics(self, notif): if notif.topic: self.metrics.increment("ua.notification.topic") self.metrics.increment( 'ua.message_data', notif.data_length, tags=make_tags(source=notif.source))
class PushServerFactory(WebSocketServerFactory):
    """PushServerProtocol factory"""

    protocol = PushServerProtocol

    def __init__(self, conf, db, agent, clients):
        # type: (AutopushConfig, DatabaseManager, Agent, Dict) -> None
        WebSocketServerFactory.__init__(self, conf.ws_url)
        self.conf = conf
        self.db = db
        self.agent = agent
        self.clients = clients
        # Autobahn protocol tuning, pulled from the service config.
        options = dict(
            webStatus=False,
            openHandshakeTimeout=5,
            autoPingInterval=conf.auto_ping_interval,
            autoPingTimeout=conf.auto_ping_timeout,
            maxConnections=conf.max_connections,
            closeHandshakeTimeout=conf.close_handshake_timeout,
        )
        self.setProtocolOptions(**options)
class RouterHandler(BaseHandler):
    """Router Handler

    Handles routing a notification to a connected client from an
    endpoint.
    """

    def put(self, uaid):
        """HTTP Put

        Attempt delivery of a notification to a connected client.
        """
        client = self.application.clients.get(uaid)
        if not client:
            self.set_status(404, reason=None)
            self.write("Client not connected.")
            return

        if client.paused:
            # Connected but mid-operation; the endpoint should retry.
            self.set_status(503, reason=None)
            self.write("Client busy.")
            return

        client.send_notification(json.loads(self.request.body))
        self.write("Client accepted for delivery")
class NotificationHandler(BaseHandler):

    def put(self, uaid, *args):
        """HTTP Put

        Notify a connected client that it should check storage for new
        notifications.
        """
        client = self.application.clients.get(uaid)
        if not client:
            self.set_status(404, reason=None)
            self.write("Client not connected.")
            return

        if client.paused:
            # Already busy: flag so the client re-checks once its
            # current work finishes.
            client._check_notifications = True
            self.set_status(202)
            self.write("Flagged for Notification check")
            return

        # Online and idle: start a notification check right away.
        client.process_notifications()
        self.metrics.increment("ua.notification_check")
        self.write("Notification check started")

    def delete(self, uaid, connected_at):
        """HTTP Delete

        Drop a connected client as the client has connected to a new
        node.
        """
        client = self.application.clients.get(uaid)
        if not client:
            return
        if client.ps.connected_at != int(connected_at):
            return
        client.sendClose()
        self.write("Terminated duplicate")
class ConnectionWSSite(Site):
    """The Websocket Site"""

    def __init__(self, conf, ws_factory):
        # type: (AutopushConfig, PushServerFactory) -> None
        self.conf = conf
        self.noisy = conf.debug

        # Serve the websocket at every path, with a /status endpoint.
        root = DefaultResource(WebSocketResource(ws_factory))
        root.putChild("status", StatusResource())
        Site.__init__(self, root)

    def ssl_cf(self):
        # type: () -> Optional[AutopushSSLContextFactory]
        """Build our SSL Factory (if configured).

        Configured from the ssl_key/cert/dh_param values.
        """
        return self.conf.ssl.cf()


class DefaultResource(Resource):
    """Delegates rendering to a default resource."""

    def __init__(self, resource):
        Resource.__init__(self)
        self.resource = resource

    def getChild(self, path, request):
        # Every child path resolves to the wrapped resource.
        return self.resource

    def render(self, request):  # pragma: nocover
        return self.resource.render(request)


class StatusResource(Resource):
    # Leaf resource: render here, never descend further.
    isLeaf = True

    def render(self, request):
        request.setResponseCode(200)
        request.setHeader("content-type", "application/json")
        body = {"status": "OK", "version": __version__}
        return json.dumps(body)