import re
import time
from cryptography.fernet import InvalidToken
from cryptography.exceptions import InvalidSignature
from marshmallow import (
Schema,
fields,
pre_load,
post_load,
validates,
validates_schema,
)
from marshmallow_polyfield import PolyField
from marshmallow.validate import Equal
from twisted.logger import Logger # noqa
from twisted.internet.defer import Deferred # noqa
from twisted.internet.defer import maybeDeferred
from twisted.internet.threads import deferToThread
from typing import ( # noqa
Any,
Dict,
Optional
)
from jose import JOSEError, JWTError
from autopush.crypto_key import CryptoKey
from autopush.db import DatabaseManager # noqa
from autopush.metrics import IMetrics, make_tags # noqa
from autopush.db import hasher
from autopush.exceptions import (
InvalidRequest,
InvalidTokenException,
ItemNotFound,
VapidAuthException,
)
from autopush.types import JSONDict # noqa
from autopush.utils import (
base64url_encode,
extract_jwt,
ms_time,
WebPushNotification,
normalize_id,
parse_auth_header,
)
from autopush.web.base import (
threaded_validate,
BaseWebHandler,
PREF_SCHEME,
)
# Largest TTL a message may keep: 60 days, expressed in seconds.
MAX_TTL = 60 * 24 * 60 * 60
# One or more URL-safe base64 characters, optionally followed by "=" padding.
VALID_BASE64_URL = re.compile(r'^[0-9A-Za-z\-_]+=*$')
# Router types a stored user record is allowed to carry.
VALID_ROUTER_TYPES = [
    "simplepush",
    "webpush",
    "gcm",
    "fcm",
    "apns",
    "adm",
]
class WebPushSubscriptionSchema(Schema):
    """Validate the subscription (UAID and channel ID) behind an endpoint.

    Before loading, the endpoint token is handed to ``conf.parse_endpoint``;
    once the fields have validated, the user record is fetched from the
    router table and checked. The record is stored under ``user_data``.
    """
    uaid = fields.UUID(required=True)
    chid = fields.UUID(required=True)
    public_key = fields.Raw(missing=None)

    @pre_load
    def extract_subscription(self, d):
        """Replace the raw token info with the parsed endpoint data.

        Raises InvalidRequest with status 401 / errno 109 when parsing
        reports a VapidAuthException, and with status 404 / errno 102 when
        the token itself is invalid.
        """
        conf = self.context["conf"]
        try:
            return conf.parse_endpoint(
                self.context["metrics"],
                token=d["token"],
                version=d["api_ver"],
                ckey_header=d["ckey_header"],
                auth_header=d["auth_header"],
            )
        except VapidAuthException as ex:
            raise InvalidRequest("missing authorization header: {}".format(ex),
                                 status_code=401, errno=109)
        except (InvalidTokenException, InvalidToken):
            raise InvalidRequest("invalid token", status_code=404, errno=102)

    @validates_schema(skip_on_field_errors=True)
    def validate_uaid_month_and_chid(self, d):
        """Look up the user record for ``d["uaid"]`` and validate it.

        Every failure is raised as InvalidRequest with status 410. Records
        with an unknown router type are dropped from the router table;
        gcm/fcm records missing their sender data are flagged with a
        ``critical_failure`` and re-registered before being rejected.
        """
        db = self.context["db"]  # type: DatabaseManager
        try:
            user = db.router.get_uaid(d["uaid"].hex)
        except ItemNotFound:
            raise InvalidRequest("UAID not found", status_code=410, errno=103)

        # Without a recognised router_type the user cannot be validated.
        router_type = user.get("router_type")
        if router_type not in VALID_ROUTER_TYPES:
            self.context["log"].debug(format="Dropping User", code=102,
                                      uaid_hash=hasher(user["uaid"]),
                                      uaid_record=repr(user))
            self.context["metrics"].increment(
                "updates.drop_user",
                tags=make_tags(errno=102))
            db.router.drop_user(user["uaid"])
            raise InvalidRequest("No such subscription", status_code=410,
                                 errno=106)

        if router_type == "gcm":
            creds = user.get('router_data', {}).get("creds", {})
            if 'senderID' not in creds:
                # Persist the fact that this record is unusable.
                user['critical_failure'] = user.get('critical_failure',
                                                    "Missing SenderID")
                db.router.register_user(user)

        if router_type == "fcm":
            if 'app_id' not in user.get('router_data', {}):
                # Persist the fact that this record is unusable.
                user['critical_failure'] = user.get('critical_failure',
                                                    "Missing SenderID")
                db.router.register_user(user)

        failure = user.get("critical_failure")
        if failure:
            raise InvalidRequest("Critical Failure: %s" % failure,
                                 status_code=410,
                                 errno=105)

        # Legacy records tagged "simplepush" are handled as "webpush".
        if user["router_type"] == "simplepush":
            user["router_type"] = "webpush"
        if user["router_type"] == "webpush":
            self._validate_webpush(d, user)

        # Hand the fetched record back to the caller.
        d["user_data"] = user

    def _validate_webpush(self, d, result):
        """Check the month table and channel of a webpush user record.

        A record with no ``current_month`` (code 102), or one naming a table
        that is not in ``db.message_tables`` (code 103), is dropped. An
        unknown channel is rejected without dropping the user. All failures
        raise InvalidRequest with status 410 / errno 106.
        """
        db = self.context["db"]  # type: DatabaseManager
        log = self.context["log"]  # type: Logger
        metrics = self.context["metrics"]  # type: IMetrics
        channel_id = normalize_id(d["chid"])
        uaid = result["uaid"]

        # Work out whether, and why, this user record has to be dropped.
        drop_code = None
        if 'current_month' not in result:
            drop_code = 102
        elif result["current_month"] not in db.message_tables:
            drop_code = 103

        if drop_code is not None:
            log.debug(format="Dropping User", code=drop_code,
                      uaid_hash=hasher(uaid),
                      uaid_record=repr(result))
            metrics.increment(
                "updates.drop_user",
                tags=make_tags(errno=drop_code))
            db.router.drop_user(uaid)
            raise InvalidRequest("No such subscription", status_code=410,
                                 errno=106)

        month_table = result["current_month"]
        message = db.message_table(month_table)
        exists, chans = message.all_channels(uaid=uaid)
        known = exists and (
            channel_id.lower() in [normalize_id(chan) for chan in chans])
        if not known:
            log.debug("Unknown subscription: {channel_id}",
                      channel_id=channel_id)
            raise InvalidRequest("No such subscription", status_code=410,
                                 errno=106)
class WebPushBasicHeaderSchema(Schema):
    """Validate the non-encryption headers of a WebPush request.

    ``ttl`` and ``topic`` are optional and load as None when absent.
    """
    authorization = fields.String()
    ttl = fields.Integer(required=False, missing=None)
    topic = fields.String(required=False, missing=None)
    api_ver = fields.String()

    @validates('topic')
    def validate_topic(self, value):
        """Reject topics longer than 32 characters or not URL-safe base64.

        A missing topic (None) is accepted. Failures raise InvalidRequest
        with errno 113.
        """
        if value is None:
            return True

        if len(value) > 32:
            raise InvalidRequest("Topic must be no greater than 32 "
                                 "characters", errno=113)

        if not VALID_BASE64_URL.match(value):
            raise InvalidRequest("Topic must be URL and Filename safe Base"
                                 "64 alphabet", errno=113)

    @validates('ttl')
    def validate_ttl(self, value):
        """Reject negative TTLs (errno 114); None and 0 are accepted."""
        if value is not None and value < 0:
            raise InvalidRequest("TTL must be greater than 0", errno=114)

    @post_load
    def cap_ttl(self, d):
        """Clamp a supplied TTL to MAX_TTL and return the loaded data.

        A missing TTL loads as None and is left untouched: comparing None
        with an int only works by accident on Python 2 and raises TypeError
        on Python 3, so None is skipped explicitly. The dict is returned
        so the hook does not depend on "None means unchanged" semantics.
        """
        if d.get("ttl") is not None:
            d["ttl"] = min(d["ttl"], MAX_TTL)
        return d
class WebPushCrypto01HeaderSchema(Schema):
    """Validates WebPush Message Encryption

    Uses draft-ietf-webpush-encryption-01 rules for validation.

    Content-Encoding must equal "aesgcm128". The custom validators below
    raise InvalidRequest with status 400 / errno 110 on failure.
    """
    content_encoding = fields.String(
        required=True,
        load_from="content-encoding",
        validate=Equal("aesgcm128")
    )
    encryption = fields.String(required=True)
    encryption_key = fields.String(
        required=True,
        load_from="encryption-key"
    )
    crypto_key = fields.String(load_from="crypto-key")

    @validates("encryption")
    def validate_encryption(self, value):
        """Must contain a salt value"""
        salt = CryptoKey.parse_and_get_label(value, "salt")
        if not salt or not VALID_BASE64_URL.match(salt):
            raise InvalidRequest("Invalid salt value in Encryption header",
                                 status_code=400,
                                 errno=110)

    @validates("crypto_key")
    def validate_crypto_key(self, value):
        """Must not contain a dh value"""
        dh = CryptoKey.parse_and_get_label(value, "dh")
        if dh:
            raise InvalidRequest(
                "dh value in Crypto-Key header not valid for 01 or earlier "
                "webpush-encryption",
                status_code=400,
                errno=110,
            )

    @validates("encryption_key")
    def validate_encryption_key(self, value):
        """Must contain a dh value that is URL-safe base64"""
        dh = CryptoKey.parse_and_get_label(value, "dh")
        # Match the parsed value itself; matching the literal "dh" always
        # succeeds and would let malformed keys through.
        if not dh or not VALID_BASE64_URL.match(dh):
            raise InvalidRequest("Invalid dh value in Encryption-Key header",
                                 status_code=400,
                                 errno=110)
class WebPushCrypto04HeaderSchema(Schema):
    """Validates WebPush Message Encryption

    Uses draft-ietf-httpbis-encryption-encoding-04 rules for validation.

    Content-Encoding must equal "aesgcm". The custom validators below raise
    InvalidRequest with status 400 / errno 110 on failure.
    """
    content_encoding = fields.String(
        required=True,
        load_from="content-encoding",
        validate=Equal("aesgcm")
    )
    encryption = fields.String(required=True)
    crypto_key = fields.String(
        load_from="crypto-key",
    )

    @validates("encryption")
    def validate_encryption(self, value):
        """Must contain a salt value"""
        salt = CryptoKey.parse_and_get_label(value, "salt")
        if not salt or not VALID_BASE64_URL.match(salt):
            raise InvalidRequest("Invalid salt value in Encryption header",
                                 status_code=400,
                                 errno=110)

    @validates("crypto_key")
    def validate_crypto_key(self, value):
        """Must contain a dh value that is URL-safe base64"""
        dh = CryptoKey.parse_and_get_label(value, "dh")
        # Match the parsed value itself; matching the literal "dh" always
        # succeeds and would let malformed keys through.
        if not dh or not VALID_BASE64_URL.match(dh):
            # The value being validated comes from the Crypto-Key header,
            # so that is the header named in the error.
            raise InvalidRequest("Invalid dh value in Crypto-Key header",
                                 status_code=400,
                                 errno=110)

    @validates_schema(pass_original=True)
    def reject_encryption_key(self, data, original_data):
        """Reject requests that still send an Encryption-Key header."""
        if "encryption-key" in original_data:
            raise InvalidRequest(
                "Encryption-Key header not valid for 02 or later "
                "webpush-encryption",
                status_code=400,
                errno=110,
            )
class WebPushCrypto06HeaderSchema(Schema):
    """Header rules for the "aes128gcm" content encoding.

    Follows draft-ietf-httpbis-encryption-encoding-06. The Encryption and
    Crypto-Key headers are optional here, and when present they must not
    carry a ``salt`` or ``dh`` label respectively.
    """
    content_encoding = fields.String(
        required=True,
        load_from="content-encoding",
        validate=Equal("aes128gcm")
    )
    encryption = fields.String(required=False)
    crypto_key = fields.String(required=False,
                               load_from="crypto-key")

    @validates("encryption")
    def validate_encryption(self, value):
        """Raise InvalidRequest (400 / errno 110) if a salt is present."""
        salt = CryptoKey.parse_and_get_label(value, "salt")
        if not salt:
            return
        raise InvalidRequest("Do not include 'salt' in aes128gcm "
                             "Encryption header",
                             status_code=400,
                             errno=110)

    @validates("crypto_key")
    def validate_crypto_key(self, value):
        """Raise InvalidRequest (400 / errno 110) if a dh key is present."""
        dh = CryptoKey.parse_and_get_label(value, "dh")
        if not dh:
            return
        raise InvalidRequest("Do not include 'dh' in aes128gcm "
                             "Crypto-Key header",
                             status_code=400,
                             errno=110)
class WebPushInvalidContentEncodingSchema(Schema):
    """Schema whose validation always fails.

    Loading any data through it raises the "Unknown Content-Encoding"
    InvalidRequest (status 400, errno 110).
    """

    @validates_schema
    def invalid_content_encoding(self, d):
        """Unconditionally raise the unknown-encoding error."""
        failure = InvalidRequest("Unknown Content-Encoding",
                                 status_code=400,
                                 errno=110)
        raise failure
def conditional_crypto_deserialize(object_dict, parent_object_dict):
    """Pick the schema used to validate a request's crypto headers.

    When the parent request has no (truthy) ``body`` a bare ``Schema`` is
    returned. Otherwise the ``content-encoding`` entry of ``object_dict``
    selects the matching crypto schema, and any other value selects the
    schema that always reports an unknown Content-Encoding.
    """
    if not parent_object_dict.get("body"):
        return Schema()

    encoding = object_dict.get("content-encoding")
    known_encodings = (
        ("aesgcm128", WebPushCrypto01HeaderSchema),
        ("aesgcm", WebPushCrypto04HeaderSchema),
        ("aes128gcm", WebPushCrypto06HeaderSchema),
    )
    for name, schema_class in known_encodings:
        if encoding == name:
            return schema_class()
    return WebPushInvalidContentEncodingSchema()
class WebPushRequestSchema(Schema):
    """Top-level schema for an incoming WebPush request.

    Combines the subscription, basic header and crypto header schemas,
    checks the body size, verifies any VAPID authorization and finally
    builds the ``WebPushNotification`` stored under ``notification``.
    """
    subscription = fields.Nested(WebPushSubscriptionSchema,
                                 load_from="token_info")
    headers = fields.Nested(WebPushBasicHeaderSchema)
    crypto_headers = PolyField(
        load_from="headers",
        deserialization_schema_selector=conditional_crypto_deserialize,
    )
    body = fields.Raw()
    token_info = fields.Raw()
    vapid_version = fields.String(required=False, missing=None)

    @validates('body')
    def validate_data(self, value):
        """Raise InvalidRequest (errno 104) if the body exceeds max_data."""
        max_data = self.context["conf"].max_data
        too_large = value and len(value) > max_data
        if too_large:
            raise InvalidRequest(
                "Data payload must be smaller than {}".format(max_data),
                errno=104,
            )

    @pre_load
    def token_prep(self, d):
        """Assemble ``token_info`` from the path arguments and headers."""
        path_kwargs = d["path_kwargs"]
        api_ver = path_kwargs.get("api_ver")
        token = path_kwargs.get("token")
        headers = d["headers"]
        d["token_info"] = {
            "api_ver": api_ver,
            "token": token,
            "ckey_header": headers.get("crypto-key", ""),
            "auth_header": headers.get("authorization", ""),
        }
        return d

    def validate_auth(self, d):
        """Verify the VAPID Authorization header and record its claims.

        Nothing happens when the endpoint is not "v2" and no Authorization
        header was sent. Otherwise the JWT is extracted and its ``aud`` and
        ``exp`` claims are checked; any failure raises InvalidRequest with
        status 401 / errno 109 and a ``www-authenticate`` header. On
        success the key and claims are stored under ``d["jwt"]``.
        """
        conf = self.context['conf']

        def unauthorized(message):
            # Every auth failure shares the same status, errno and header.
            return InvalidRequest(message,
                                  status_code=401, errno=109,
                                  headers={"www-authenticate": PREF_SCHEME})

        # The set of "bad token" exceptions depends on the crypto backend.
        if conf.use_cryptography:
            backend_errors = [InvalidSignature]
        else:
            backend_errors = [JOSEError, JWTError, AssertionError]
        crypto_exceptions = tuple(
            [KeyError, ValueError, TypeError, VapidAuthException] +
            backend_errors)

        auth = d["headers"].get("authorization")
        needs_auth = d["token_info"]["api_ver"] == "v2"
        if not (needs_auth or auth):
            return

        try:
            vapid_auth = parse_auth_header(auth)
            token = vapid_auth['t']
            d["vapid_version"] = "draft{:0>2}".format(
                vapid_auth['version'])
            if vapid_auth['version'] == 2:
                public_key = vapid_auth['k']
            else:
                public_key = d["subscription"].get("public_key")
            jwt = extract_jwt(
                token,
                public_key,
                is_trusted=conf.enable_tls_auth,
                use_crypto=conf.use_cryptography
            )
            if not isinstance(jwt, Dict):
                raise unauthorized("Invalid Authorization Header")
        except crypto_exceptions:
            raise unauthorized("Invalid Authorization Header")

        if "aud" not in jwt:
            raise unauthorized("Invalid bearer token: No Audience specified")
        if jwt['aud'] != conf.endpoint_url:
            raise unauthorized(
                "Invalid bearer token: Invalid Audience Specified")
        if "exp" not in jwt:
            raise unauthorized("Invalid bearer token: No expiration")

        try:
            jwt_expires = int(jwt['exp'])
        except (TypeError, ValueError):
            raise unauthorized("Invalid bearer token: Invalid expiration")

        now = time.time()
        if now > jwt_expires:
            raise unauthorized("Invalid bearer token: Auth expired")
        one_day = 60*60*24
        if (jwt_expires - now) > one_day:
            raise unauthorized("Invalid bearer token: Auth > 24 hours in "
                               "the future")

        d["jwt"] = dict(jwt_crypto_key=base64url_encode(public_key),
                        jwt_data=jwt)

    @post_load
    def fixup_output(self, d):
        """Authorize the request and shape the loaded data for the handler.

        Authorization is checked in this hook because it needs the output
        of the nested schemas, which is not yet available when schema-level
        validation runs.
        """
        self.validate_auth(d)

        # Fold the validated crypto headers back into the header dict,
        # restoring the dashed header names.
        crypto_headers = d["crypto_headers"]
        if crypto_headers:
            dashed = {}
            for name, header_value in crypto_headers.items():
                dashed[name.replace("_", "-")] = header_value
            d["headers"].update(dashed)

        # The body is carried as URL-safe base64.
        d["body"] = base64url_encode(d["body"])

        # Build the notification from the validated request data.
        conf = self.context["conf"]
        d["notification"] = WebPushNotification.from_webpush_request_schema(
            data=d, fernet=conf.fernet,
            legacy=conf._notification_legacy,
        )
        return d
class WebPushHandler(BaseWebHandler):
    """Request handler that accepts WebPush notifications over POST.

    The request is validated against ``WebPushRequestSchema`` before
    ``post`` runs; the notification is then handed to the router that
    matches the user's ``router_type``.
    """
    cors_methods = "POST"
    cors_request_headers = ("content-encoding", "encryption",
                            "crypto-key", "ttl",
                            "encryption-key", "content-type",
                            "authorization")
    cors_response_headers = ("location", "www-authenticate")

    def initialize(self):
        """Must run on initialization to set ahead of validation"""
        super(WebPushHandler, self).initialize()
        self._handling_message = True

    @threaded_validate(WebPushRequestSchema)
    def post(self,
             subscription,  # type: Dict[str, Any]
             notification,  # type: WebPushNotification
             jwt=None,  # type: Optional[JSONDict]
             **kwargs  # type: Any
             ):
        # type: (...) -> Deferred
        """Route a validated notification and return the routing Deferred.

        :param subscription: loaded subscription data; its ``user_data``
            entry is the user's router record.
        :param notification: the notification built during validation.
        :param jwt: VAPID key and claims when the request was authorized.
        :param kwargs: remaining schema output; only ``vapid_version`` is
            read here.
        """
        # Store Vapid info if present
        if jwt:
            self.metrics.increment("updates.vapid.{}".format(
                kwargs.get('vapid_version'))
            )
            self._client_info["jwt_crypto_key"] = jwt["jwt_crypto_key"]
            for claim, claim_value in jwt["jwt_data"].items():
                self._client_info["jwt_" + claim] = claim_value

        user_data = subscription["user_data"]

        encoding = ''
        if notification.data and notification.headers:
            encoding = notification.headers.get('encoding', '')
            # NOTE(review): counted only for messages that carry data and
            # headers -- nesting inferred from control flow, please confirm.
            self.metrics.increment(
                "updates.notification.encoding.{}".format(encoding)
            )

        self._client_info.update(
            message_id=notification.message_id,
            uaid_hash=hasher(user_data.get("uaid")),
            channel_id=notification.channel_id.hex,
            router_key=user_data["router_type"],
            message_size=notification.data_length,
            message_ttl=notification.ttl,
            version=notification.version,
            encoding=encoding,
        )

        router_type = user_data["router_type"]
        router = self.routers[router_type]
        # Start of the routing interval measured in _router_completed.
        self._router_time = time.time()
        d = maybeDeferred(router.route_notification, notification, user_data)
        d.addCallback(self._router_completed, user_data, "",
                      router_type=router_type,
                      vapid=jwt)
        d.addErrback(self._router_fail_err,
                     router_type=router_type,
                     vapid=jwt is not None,
                     uaid=user_data.get("uaid"))
        d.addErrback(self._response_err)
        return d

    def _router_completed(self, response, uaid_data, warning="",
                          router_type=None, vapid=None):
        """Called after router has completed successfully

        If the router response carries ``router_data`` the user record is
        first deleted (empty data) or re-registered (non-empty data) in a
        thread; in those cases the Deferred for that work is returned.
        Otherwise the outcome is logged, timed and passed on to
        ``_router_response``.
        """
        # Log the time taken for routing
        self._timings["route_time"] = time.time() - self._router_time
        # Were we told to update the router data?
        time_diff = time.time() - self._start_time
        if response.router_data is not None:
            if not response.router_data:
                # An empty router_data object indicates that the record should
                # be deleted. There is no longer valid route information for
                # this record.
                self.log.debug(format="Dropping User", code=100,
                               uaid_hash=hasher(uaid_data["uaid"]),
                               uaid_record=repr(uaid_data),
                               client_info=self._client_info)
                d = deferToThread(self.db.router.drop_user, uaid_data["uaid"])
                d.addCallback(lambda x: self._router_response(response,
                                                              router_type,
                                                              vapid))
                return d
            # The router data needs to be updated to include any changes
            # requested by the bridge system
            uaid_data["router_data"] = response.router_data
            # set the AWS mandatory data
            uaid_data["connected_at"] = ms_time()
            d = deferToThread(self.db.router.register_user, uaid_data)
            # Clear router_data so the re-entrant call below takes the
            # "no changes requested" branch.
            response.router_data = None
            d.addCallback(lambda x: self._router_completed(
                response,
                uaid_data,
                warning,
                router_type,
                vapid))
            return d
        else:
            # No changes are requested by the bridge system, proceed as normal
            if response.status_code == 200 or response.logged_status == 200:
                self.log.debug(format="Successful delivery",
                               client_info=self._client_info)
            elif response.status_code == 202 or response.logged_status == 202:
                self.log.debug(
                    format="Router miss, message stored.",
                    client_info=self._client_info)
            # NOTE(review): timing is recorded for every status, not just
            # 202 -- nesting inferred from control flow, please confirm.
            self.metrics.timing("notification.request_time",
                                duration=time_diff)
            response.response_body = (
                response.response_body + " " + warning).strip()
            self._router_response(response, router_type, vapid)