"""Source code for autopush.web.registration"""

import re
import uuid

from marshmallow_polyfield import PolyField
from typing import (  # noqa
    Optional,
    Set,
    Tuple
)

import simplejson as json
from cryptography.hazmat.primitives import constant_time
from marshmallow import (
    Schema,
    fields,
    pre_load,
    post_load,
    validate,
    validates,
    validates_schema,
)
from twisted.internet.defer import Deferred  # noqa
from twisted.internet.threads import deferToThread

from autopush.db import generate_last_connect, hasher
from autopush.exceptions import InvalidRequest, ItemNotFound, RouterException
from autopush.types import JSONDict  # noqa
from autopush.utils import generate_hash, ms_time
from autopush.web.base import (
    threaded_validate,
    BaseWebHandler,
    PREF_SCHEME,
    AUTH_SCHEMES
)


class PathUUID(fields.Field):
    """UUID field for values taken from the URL path.

    A value that cannot be parsed as a UUID is reported as a 404
    ("Invalid Path") rather than as an ordinary field error.
    """
    def _deserialize(self, value, attr, data):
        try:
            parsed = uuid.UUID(value)
        except ValueError:
            raise InvalidRequest("Invalid Path", status_code=404)
        return parsed


#############################################################
# Payload body validation
#############################################################
class SubInfoSchema(Schema):
    """Validates key/channelID if present"""
    key = fields.Str(allow_none=True)
    channelID = fields.UUID(allow_none=True)

    @pre_load
    def load_body(self, value):
        """Decode the raw JSON request body before the fields are loaded.

        Raises InvalidRequest (400, errno 108) when the body is not
        valid JSON.
        """
        try:
            return json.loads(value)
        except ValueError:
            raise InvalidRequest("Invalid Request body", status_code=400,
                                 errno=108)

    @post_load
    def convert_chid(self, data):
        """Normalize a supplied channelID to its dash-less hex form.

        The field is declared with allow_none=True, so an explicit JSON
        null can arrive here as None. It is treated the same as an
        absent channelID (the key is dropped) instead of failing on
        ``None.hex``; code that reads the result with
        ``.get("channelID", default)`` then gets its default.
        """
        chid = data.get("channelID")
        if chid is None:
            data.pop("channelID", None)
        else:
            # ALWAYS RETURN CHID AS .HEX, NO DASH
            data["channelID"] = chid.hex
        return data


class TokenSchema(SubInfoSchema):
    """Body schema that, on top of key/channelID, admits a ``token``
    string (null allowed, no format check)"""
    token = fields.Str(allow_none=True)


# A token is any run of 8 or more characters containing no space.
valid_token = validate.Regexp(r"^[^ ]{8,}$")
# ADM tokens: the fixed "amzn1.adm-registration.v3." prefix followed by
# 256 or more non-space characters. The dots are escaped so they match a
# literal "." only; unescaped, each one would accept any character.
valid_adm_token = validate.Regexp(
    r"^amzn1\.adm-registration\.v3\.[^ ]{256,}$")


class GCMTokenSchema(SubInfoSchema):
    """Body schema whose ``token`` may not be null and must satisfy
    ``valid_token``"""
    token = fields.Str(
        validate=valid_token,
        allow_none=False,
        error="Missing required token value",
    )


class APNSTokenSchema(SubInfoSchema):
    """Body schema with a non-null ``token`` checked by ``valid_token``
    and an optional ``aps`` dictionary"""
    token = fields.Str(
        validate=valid_token,
        allow_none=False,
        error="Missing required token value",
    )
    aps = fields.Dict(allow_none=True)


class ADMTokenSchema(SubInfoSchema):
    """Body schema whose ``token`` may not be null and must satisfy
    ``valid_adm_token``"""
    token = fields.Str(
        validate=valid_adm_token,
        allow_none=False,
        error="Missing required token value",
    )


#############################################################
# URI argument validation
#############################################################
class TypeAppSchema(Schema):
    """Checks the ``type`` / ``app_id`` URI arguments.

    ``type`` is loaded into ``router_type`` and must name one of the
    routers supplied in the schema context.
    """
    router_type = fields.Str(required=True, load_from="type")
    app_id = fields.Str(required=True)

    @validates("router_type")
    def validate_router_type(self, value):
        """Reject (400, errno 108) a router type missing from the context"""
        known_routers = self.context['routers']
        if value in known_routers:
            return
        raise InvalidRequest("Invalid router", status_code=400, errno=108)


class TypeAppUaidSchema(TypeAppSchema):
    """Checks the ``type`` / ``app_id`` / ``uaid`` URI arguments"""
    uaid = PathUUID(required=True)

    @validates("uaid")
    def validate_uaid(self, value):
        """Look the UAID up in the router table; an unknown UAID is
        reported as 410 (errno 103)"""
        router_table = self.context['db'].router
        try:
            router_table.get_uaid(value.hex)
        except ItemNotFound:
            raise InvalidRequest("UAID not found", status_code=410, errno=103)


class TypeAppChidUaidSchema(TypeAppUaidSchema):
    """Checks the ``type`` / ``app_id`` / ``uaid`` / ``chid`` URI
    arguments; ``chid`` is parsed as a path UUID"""
    chid = PathUUID(required=True)


#############################################################
# Header validation
#############################################################
class AuthorizationHeaderSchema(Schema):
    """Loads the ``Authorization`` request header as ``authorization``"""
    authorization = fields.Str(load_from="Authorization")


#############################################################
# Request validation mix-ins
#############################################################
class AuthorizationCheckSchema(Schema):
    """Schema that does the authorization check

    Used as a mix-in: ``validate_auth`` reads ``data["path_kwargs"]["uaid"]``
    and ``self.context['conf']``, so the concrete schema must load a
    ``path_kwargs`` containing ``uaid`` and be given ``conf`` in its context.
    """
    headers = fields.Nested(AuthorizationHeaderSchema)

    @validates_schema(skip_on_field_errors=True)
    def validate_auth(self, data):
        """Verify the Authorization header against the UAID in the path.

        Raises InvalidRequest (401, errno 109, with a ``www-authenticate``
        header) when the header is missing, is not of the form
        "<scheme> <token>", uses a scheme not in AUTH_SCHEMES, or carries
        a token that matches none of the configured hash keys.
        """
        request_pref_header = {'www-authenticate': PREF_SCHEME}
        auth = data["headers"].get("authorization")
        if not auth:
            raise InvalidRequest("Unauthorized", status_code=401, errno=109,
                                 headers=request_pref_header)

        try:
            # Runs of spaces are collapsed first, so extra padding between
            # scheme and token is tolerated. Anything that does not split
            # into exactly two parts fails the unpacking below.
            auth_type, auth_token = re.sub(
                r' +', ' ', auth.strip()).split(" ", 2)
        except ValueError:
            raise InvalidRequest("Invalid Authentication", status_code=401,
                                 errno=109,
                                 headers=request_pref_header)
        if auth_type.lower() not in AUTH_SCHEMES:
            raise InvalidRequest("Invalid Authentication",
                                 status_code=401,
                                 errno=109,
                                 headers=request_pref_header)

        conf = self.context['conf']
        uaid = data["path_kwargs"]["uaid"]
        # The token is only verified when at least one hash key is
        # configured.
        if conf.bear_hash_key:
            is_valid = False
            # Every key is tried, with a constant-time comparison and no
            # early exit on a match.
            for key in conf.bear_hash_key:
                test_token = generate_hash(key, uaid.hex)
                is_valid |= constant_time.bytes_eq(bytes(test_token),
                                                   bytes(auth_token))
            if not is_valid:
                raise InvalidRequest("Invalid Authentication",
                                     status_code=401,
                                     errno=109,
                                     headers=request_pref_header)


def conditional_token_check(object_dict, parent_dict):
    """Pick the body schema instance that matches the ``type`` path argument.

    Raises InvalidRequest (400, errno 108) for a type outside the known
    set. Known types without a dedicated schema get a plain TokenSchema.
    """
    ptype = parent_dict['path_kwargs']['type']
    # Basic "bozo-filter" to prevent customer surprises later.
    if ptype not in ('apns', 'fcm', 'gcm', 'webpush', 'adm', 'test'):
        raise InvalidRequest("Unknown registration type",
                             status_code=400,
                             errno=108,
                             )
    dedicated_schemas = {
        'gcm': GCMTokenSchema,
        'fcm': GCMTokenSchema,
        'apns': APNSTokenSchema,
        'adm': ADMTokenSchema,
    }
    schema_cls = dedicated_schemas.get(ptype, TokenSchema)
    return schema_cls()


class RouterDataSchema(Schema):
    """Loads the request body as ``router_data`` using the schema chosen
    by ``conditional_token_check``, then hands it to the router"""
    router_data = PolyField(
        load_from="body",
        deserialization_schema_selector=conditional_token_check)

    @validates_schema(skip_on_field_errors=True)
    def register_router(self, data):
        """Call the selected router's ``register``; a RouterException is
        re-raised as InvalidRequest with the same status, errno and headers"""
        path_args = data["path_kwargs"]
        router = self.context["routers"][path_args["router_type"]]
        try:
            router.register(uaid="",
                            router_data=data["router_data"],
                            app_id=path_args["app_id"])
        except RouterException as exc:
            raise InvalidRequest(str(exc), status_code=exc.status_code,
                                 errno=exc.errno, headers=exc.headers)


#############################################################
# Handler method validators
#############################################################
class NewRegistrationSchema(RouterDataSchema):
    """Request schema for a new registration (type/app_id path + body)"""
    path_kwargs = fields.Nested(TypeAppSchema)

    @post_load
    def extract_needed(self, data):
        """Keep only ``router_type`` and ``router_data``"""
        path_args = data["path_kwargs"]
        return {
            "router_type": path_args["router_type"],
            "router_data": data["router_data"],
        }


class GetUaidChannelSchema(AuthorizationCheckSchema):
    """Authorized request schema carrying a UAID in the path"""
    path_kwargs = fields.Nested(TypeAppUaidSchema)

    @post_load
    def extract_needed(self, data):
        """Keep only the ``uaid``"""
        path_args = data["path_kwargs"]
        return {"uaid": path_args["uaid"]}


class TokenUpdateSchema(AuthorizationCheckSchema, RouterDataSchema):
    """Authorized request schema for updating router data of a UAID"""
    path_kwargs = fields.Nested(TypeAppUaidSchema)

    @post_load
    def extract_needed(self, data):
        """Keep only ``router_data``, ``router_type`` and ``uaid``"""
        path_args = data["path_kwargs"]
        return {
            "router_data": data["router_data"],
            "router_type": path_args["router_type"],
            "uaid": path_args["uaid"],
        }


class UnregisterUaidSchema(AuthorizationCheckSchema):
    """Authorized request schema for removing a UAID"""
    path_kwargs = fields.Nested(TypeAppUaidSchema)

    @post_load
    def extract_needed(self, data):
        """Keep only the ``uaid``"""
        path_args = data["path_kwargs"]
        return {"uaid": path_args["uaid"]}


class NewChidSchema(AuthorizationCheckSchema):
    """Authorized request schema for adding a channel to a UAID"""
    body = fields.Nested(SubInfoSchema)
    path_kwargs = fields.Nested(TypeAppUaidSchema)

    @post_load
    def extract_needed(self, data):
        """Return ``uaid``, ``chid`` and ``app_server_key``.

        ``chid`` is the channelID from the body when one is present,
        otherwise a freshly generated UUID in hex form.
        """
        body = data["body"]
        generated_chid = uuid.uuid4().hex
        return {
            "uaid": data["path_kwargs"]["uaid"],
            "chid": body.get("channelID", generated_chid),
            "app_server_key": body.get("key"),
        }


class UnregisterChidSchema(AuthorizationCheckSchema):
    """Authorized request schema for removing a channel from a UAID"""
    path_kwargs = fields.Nested(TypeAppChidUaidSchema)

    @post_load
    def extract_needed(self, data):
        """Keep only the ``uaid`` and ``chid``"""
        path_args = data["path_kwargs"]
        return {
            "uaid": path_args["uaid"],
            "chid": path_args["chid"],
        }
#############################################################


class BaseRegistrationHandler(BaseWebHandler):
    """Helpers shared by the registration handlers"""

    def base_tags(self):
        """Return a copy of the base tags extended with user-agent and
        host tags taken from the current request"""
        tags = list(self._base_tags)
        user_agent = self.request.headers.get("user-agent")
        tags.extend([
            "user_agent:%s" % user_agent,
            "host:%s" % self.request.host,
        ])
        return tags

    def _register_channel(self, uaid, chid, app_server_key):
        # type: (uuid.UUID, str, Optional[str]) -> str
        """Store the channel for the UAID, then build and return its
        endpoint"""
        uaid_hex = uaid.hex
        self.db.message.register_channel(uaid_hex, chid)
        return self.conf.make_endpoint(uaid_hex, chid, app_server_key)

    def _register_user(self, uaid, router_type, router_data):
        # type: (uuid.UUID, str, JSONDict) -> None
        """Write a user record to the router table.

        The expiry is fixed at 0: mobile users never check back, so a
        real expiry could wrongly expire the record. (Expiry records
        older than 5 years are not automatically deleted.)
        """
        user_record = {
            "uaid": uaid.hex,
            "router_type": router_type,
            "router_data": router_data,
            "connected_at": ms_time(),
            "last_connect": generate_last_connect(),
            "expiry": 0,
        }
        self.db.router.register_user(user_record)

    def _write_endpoint(self, endpoint, uaid, chid, router_type, router_data,
                        new_uaid=False):
        # type: (str, uuid.UUID, str, str, JSONDict, bool) -> None
        """Send the created endpoint back as JSON and log the registration.

        For a new UAID the response also carries the uaid and its secret
        (None when no hash key is configured), and the router gets a
        chance to amend the response.
        """
        response = {"channelID": chid, "endpoint": endpoint}
        if new_uaid:
            hash_keys = self.conf.bear_hash_key
            if hash_keys:
                secret = generate_hash(hash_keys[0], uaid.hex)
            else:
                secret = None
            response["uaid"] = uaid.hex
            response["secret"] = secret
            # Apply any router specific fixes to the outbound response.
            self.routers[router_type].amend_endpoint_response(
                response, router_data)
        self.set_header("Content-Type", "application/json")
        self.write(json.dumps(response))
        self.log.info("Register",
                      client_info=self._client_info,
                      endpoint=endpoint,
                      uaid_hash=hasher(uaid.hex))
        self.finish()

    def _success(self, result):
        """Finish the request with an empty JSON object; ``result`` is
        ignored"""
        self.set_header("Content-Type", "application/json")
        self.write("{}")
        self.finish()


class NewRegistrationHandler(BaseRegistrationHandler):
    """Handle new bridge uaid registrations"""
    cors_methods = "POST"

    @threaded_validate(NewRegistrationSchema)
    def post(self, router_type, router_data):
        # type: (str, JSONDict) -> Deferred
        """HTTP POST

        Router type/data registration.

        A new UAID is generated for every call. The channel ID comes from
        the router data when supplied, otherwise it is generated too. The
        user and channel are stored on a worker thread and the endpoint
        is then written back.

        """
        self.metrics.increment("ua.command.register", tags=self.base_tags())

        uaid = uuid.uuid4()

        # ALWAYS RETURN CHID AS .HEX, NO DASH
        chid = router_data.get("channelID", uuid.uuid4().hex)

        d = deferToThread(self._register_user_and_channel,
                          uaid, chid, router_type, router_data)
        d.addCallback(self._write_endpoint,
                      uaid, chid, router_type, router_data, new_uaid=True)
        d.addErrback(self._router_fail_err)
        d.addErrback(self._response_err)
        return d

    def _register_user_and_channel(self, uaid, chid, router_type,
                                   router_data):
        # type: (uuid.UUID, str, str, JSONDict) -> str
        """Register a new user/channel, return its endpoint"""
        self._register_user(uaid, router_type, router_data)
        return self._register_channel(uaid, chid, router_data.get("key"))
class UaidRegistrationHandler(BaseRegistrationHandler):
    """Handles UAID bridge methods"""
    cors_methods = "GET,POST,PUT,DELETE"

    @threaded_validate(GetUaidChannelSchema)
    def get(self, uaid):
        # type: (uuid.UUID) -> Deferred
        """HTTP GET

        Return a list of known channelIDs for a given UAID

        """
        # NOTE(review): the UAID is passed here in its dashed str() form,
        # while the other calls in this file pass uaid.hex -- confirm that
        # all_channels accepts this form.
        d = deferToThread(self.db.message.all_channels, str(uaid))
        d.addCallback(self._write_channels, uaid)
        d.addErrback(self._uaid_not_found_err)
        d.addErrback(self._response_err)
        return d

    @threaded_validate(TokenUpdateSchema)
    def put(self, router_type, router_data, uaid):
        # type: (str, JSONDict, uuid.UUID) -> Deferred
        """HTTP PUT

        Update router type/data for a UAID.

        """
        d = deferToThread(self._register_user, uaid, router_type, router_data)
        d.addCallback(self._success)
        d.addErrback(self._router_fail_err)
        d.addErrback(self._response_err)
        return d

    # POST is served by the same handler as PUT.
    post = put

    @threaded_validate(UnregisterUaidSchema)
    def delete(self, uaid):
        # type: (uuid.UUID) -> Deferred
        """HTTP DELETE

        Delete all pending records for the given UAID

        """
        d = deferToThread(self._delete_uaid, uaid)
        d.addCallback(self._success)
        d.addErrback(self._uaid_not_found_err)
        d.addErrback(self._response_err)
        return d

    def _delete_uaid(self, uaid):
        """Drop the user from the router table; raises ItemNotFound when
        ``drop_user`` returns a falsy result"""
        self.log.debug(format="Dropping User", code=101,
                       uaid_hash=hasher(uaid.hex))
        if not self.db.router.drop_user(uaid.hex):
            raise ItemNotFound("UAID not found")

    def _uaid_not_found_err(self, fail):
        """errBack for uaid lookup not finding the user"""
        fail.trap(ItemNotFound)
        self.log.debug(format="UAID not found in AWS.",
                       status_code=410, errno=103,
                       client_info=self._client_info)
        self._write_response(410, errno=103,
                             message="Endpoint has expired. "
                                     "Do not send messages to this endpoint.")

    def _write_channels(self, channel_info, uaid):
        # type: (Tuple[bool, Set[str]], uuid.UUID) -> None
        """Write the UAID (hex) and its channel IDs (dashed form) as JSON"""
        response = dict(
            uaid=uaid.hex,
            channelIDs=[str(uuid.UUID(x)) for x in channel_info[1]]
        )
        self.set_header("Content-Type", "application/json")
        self.write(json.dumps(response))
        self.finish()
class SubRegistrationHandler(BaseRegistrationHandler):
    """Handle a new bridge channel id registration for a bridge user"""
    cors_methods = "POST"

    @threaded_validate(NewChidSchema)
    def post(self, uaid, chid, app_server_key=None):
        # type: (uuid.UUID, str, Optional[str]) -> Deferred
        """HTTP POST

        Register the channel on a worker thread and write back its
        endpoint.

        """
        d = deferToThread(self._register_channel, uaid, chid, app_server_key)
        # new_uaid keeps its default (False), so the empty router type and
        # router data passed here are not used by _write_endpoint.
        d.addCallback(self._write_endpoint, uaid, chid, "", {})
        d.addErrback(self._response_err)
        return d
class ChannelRegistrationHandler(BaseRegistrationHandler):
    """Handle deleting a channel for a bridge user"""
    cors_methods = "DELETE"

    @threaded_validate(UnregisterChidSchema)
    def delete(self, uaid, chid):
        # type: (uuid.UUID, uuid.UUID) -> Deferred
        """HTTP DELETE

        Unregister the channel on a worker thread and reply with an
        empty JSON object.

        """
        self.metrics.increment("ua.command.unregister", tags=self.base_tags())
        d = deferToThread(self._delete_channel, uaid, chid)
        d.addCallback(self._success)
        d.addErrback(self._chid_not_found_err)
        d.addErrback(self._response_err)
        return d

    def _delete_channel(self, uaid, chid):
        # type: (uuid.UUID, uuid.UUID) -> None
        """Unregister the channel and log it; raises ItemNotFound when
        ``unregister_channel`` returns a falsy result"""
        if not self.db.message.unregister_channel(uaid.hex, chid.hex):
            raise ItemNotFound("ChannelID not found")
        self.log.info("Unregister",
                      client_info=self._client_info,
                      channel_id=chid.hex,
                      uaid_hash=hasher(uaid.hex))

    def _chid_not_found_err(self, fail):
        """errBack for unknown chid"""
        fail.trap(ItemNotFound, ValueError)
        # NOTE(review): the client info is expanded into keyword arguments
        # here, whereas the other log calls in this file pass it as
        # client_info= -- confirm this is intended.
        self.log.debug(format="CHID not found in AWS.",
                       status_code=410, errno=106,
                       **self._client_info)
        self._write_response(410, 106, message="Invalid endpoint.")