"""Autopush Config Object and Setup"""
import json
import socket
from argparse import Namespace # noqa
from hashlib import sha256
from typing import ( # noqa
Any,
Dict,
List,
Optional,
Type,
Union
)
from attr import (
attrs,
attrib,
Factory
)
from cryptography.fernet import Fernet, MultiFernet
from cryptography.hazmat.primitives import constant_time
import autopush.db as db
from autopush.exceptions import (
InvalidConfig,
InvalidTokenException,
VapidAuthException
)
from autopush.ssl import AutopushSSLContextFactory
from autopush.types import JSONDict # noqa
from autopush.utils import (
CLIENT_SHA256_RE,
canonical_url,
get_amid,
resolve_ip,
repad,
base64url_decode,
parse_auth_header,
)
from autopush.crypto_key import CryptoKey, CryptoKeyException
def _init_crypto_key(ck):
# type: (Optional[Union[str, List[str]]]) -> List[str]
"""Provide a default or ensure the provided's a list"""
# if CRYPTO_KEY is not set by docker, it may pass an empty string,
# which is converted into an Array element and prevents the config
# file value from being read
if ck is None or ck == ['']:
return [Fernet.generate_key()]
return ck if isinstance(ck, list) else [ck]
def _nested(cls, **kwargs):
# type: (Type, **Any) -> Any
"""Defines an attr cls nested within another attr.
This attribute constructs the nested attr from a dict argument
(representing its kwargs) unless already an instance of cls.
"""
def converter(arg):
return arg if isinstance(arg, cls) else cls(**arg)
return attrib(converter=converter, **kwargs)
[docs]@attrs
class SSLConfig(object):
"""AutopushSSLContextFactory configuration"""
key = attrib(default=None) # type: Optional[str]
cert = attrib(default=None) # type: Optional[str]
dh_param = attrib(default=None) # type: Optional[str]
[docs] def cf(self, **kwargs):
# type: (**Any) -> Optional[AutopushSSLContextFactory]
"""Build our AutopushSSLContextFactory (if configured)"""
if not self.key:
return None
return AutopushSSLContextFactory(
self.key,
self.cert,
dh_file=self.dh_param,
**kwargs
)
[docs]@attrs
class DDBTableConfig(object):
"""A DynamoDB Table's configuration"""
tablename = attrib() # type: str
read_throughput = attrib(default=5) # type: int
write_throughput = attrib(default=5) # type: int
[docs]@attrs
class AutopushConfig(object):
"""Main Autopush Settings Object"""
debug = attrib(default=False) # type: bool
fernet = attrib(init=False) # type: MultiFernet
_crypto_key = attrib(
converter=_init_crypto_key, default=None) # type: List[str]
bear_hash_key = attrib(default=Factory(list)) # type: List[str]
human_logs = attrib(default=True) # type: bool
hostname = attrib(default=None) # type: Optional[str]
port = attrib(default=None) # type: Optional[int]
_resolve_hostname = attrib(default=False) # type: bool
router_scheme = attrib(default=None) # type: Optional[str]
router_hostname = attrib(default=None) # type: Optional[str]
router_port = attrib(default=None) # type: Optional[int]
endpoint_scheme = attrib(default=None) # type: Optional[str]
endpoint_hostname = attrib(default=None) # type: Optional[str]
endpoint_port = attrib(default=None) # type: Optional[int]
proxy_protocol_port = attrib(default=None) # type: Optional[int]
memusage_port = attrib(default=None) # type: Optional[int]
statsd_host = attrib(default="localhost") # type: str
statsd_port = attrib(default=8125) # type: int
megaphone_api_url = attrib(default=None) # type: Optional[str]
megaphone_api_token = attrib(default=None) # type: Optional[str]
megaphone_poll_interval = attrib(default=30) # type: int
datadog_api_key = attrib(default=None) # type: Optional[str]
datadog_app_key = attrib(default=None) # type: Optional[str]
datadog_flush_interval = attrib(default=None) # type: Optional[int]
router_table = _nested(
DDBTableConfig,
default=dict(tablename="router")
) # type: DDBTableConfig
message_table = _nested(
DDBTableConfig,
default=dict(tablename="message")
) # type: DDBTableConfig
preflight_uaid = attrib(
default="deadbeef00000000deadbeef00000000") # type: str
ssl = _nested(SSLConfig, default=Factory(SSLConfig)) # type: SSLConfig
router_ssl = _nested(
SSLConfig, default=Factory(SSLConfig)) # type: SSLConfig
client_certs = attrib(default=None) # type: Optional[Dict[str, str]]
router_url = attrib(init=False) # type: str
endpoint_url = attrib(init=False) # type: str
ws_url = attrib(init=False) # type: str
router_conf = attrib(default=Factory(dict)) # type: JSONDict
# twisted Agent's connectTimeout
connect_timeout = attrib(default=0.5) # type: float
max_data = attrib(default=4096) # type: int
env = attrib(default='development') # type: str
ami_id = attrib(default=None) # type: Optional[str]
cors = attrib(default=False) # type: bool
hello_timeout = attrib(default=0) # type: int
# Force timeout in idle seconds
msg_limit = attrib(default=100) # type: int
auto_ping_interval = attrib(default=None) # type: Optional[int]
auto_ping_timeout = attrib(default=None) # type: Optional[int]
max_connections = attrib(default=None) # type: Optional[int]
close_handshake_timeout = attrib(default=None) # type: Optional[int]
# Generate messages per legacy rules, only used for testing to
# generate legacy data.
_notification_legacy = attrib(default=False) # type: bool
# Use the cryptography library
use_cryptography = attrib(default=False) # type: bool
# Strict-Transport-Security max age (Default 1 year in secs)
sts_max_age = attrib(default=31536000) # type: int
# Don't cache ssl.wrap_socket's SSLContexts
no_sslcontext_cache = attrib(default=False) # type: bool
# DynamoDB endpoint override
aws_ddb_endpoint = attrib(default=None) # type: str
allow_table_rotation = attrib(default=True) # type: bool
def __attrs_post_init__(self):
"""Initialize the Settings object"""
# Setup hosts/ports/urls
if not self.hostname:
self.hostname = socket.gethostname()
if self._resolve_hostname:
self.hostname = resolve_ip(self.hostname)
if not self.endpoint_hostname:
self.endpoint_hostname = self.hostname
if not self.router_hostname:
self.router_hostname = self.hostname
self.router_url = canonical_url(
self.router_scheme or 'http',
self.router_hostname,
self.router_port
)
self.endpoint_url = canonical_url(
self.endpoint_scheme or 'http',
self.endpoint_hostname,
self.endpoint_port
)
# not accurate under autoendpoint (like router_url)
self.ws_url = "{}://{}:{}/".format(
'wss' if self.ssl.key else 'ws',
self.hostname,
self.port
)
self.fernet = MultiFernet([Fernet(key) for key in self._crypto_key])
@property
def enable_tls_auth(self):
"""Whether TLS authentication w/ client certs is enabled"""
return self.client_certs is not None
[docs] @classmethod
def from_argparse(cls, ns, **kwargs):
# type: (Namespace, **Any) -> AutopushConfig
"""Create an instance from argparse/additional kwargs"""
router_conf = {}
if ns.key_hash:
db.key_hash = ns.key_hash
if ns.apns_creds:
# if you have the critical elements for each external
# router, create it
try:
router_conf["apns"] = json.loads(ns.apns_creds)
except (ValueError, TypeError):
raise InvalidConfig(
"Invalid JSON specified for APNS config options")
if ns.senderid_list:
# Create a common gcmclient
try:
sender_ids = json.loads(ns.senderid_list)
except (ValueError, TypeError):
raise InvalidConfig("Invalid JSON specified for senderid_list")
try:
# This is an init check to verify that things are
# configured correctly. Otherwise errors may creep in
# later that go unaccounted.
sender_ids[sender_ids.keys()[0]]
except (IndexError, TypeError):
raise InvalidConfig("No GCM SenderIDs specified or found.")
router_conf["gcm"] = {"ttl": ns.gcm_ttl,
"dryrun": ns.gcm_dryrun,
"max_data": ns.max_data,
"collapsekey": ns.gcm_collapsekey,
"senderIDs": sender_ids,
"endpoint": ns.gcm_endpoint}
client_certs = None
# endpoint only
if getattr(ns, 'client_certs', None):
try:
client_certs_arg = json.loads(ns.client_certs)
except (ValueError, TypeError):
raise InvalidConfig("Invalid JSON specified for client_certs")
if client_certs_arg:
if not ns.ssl_key:
raise InvalidConfig("client_certs specified without SSL "
"enabled (no ssl_key specified)")
client_certs = {}
for name, sigs in client_certs_arg.iteritems():
if not isinstance(sigs, list):
raise InvalidConfig(
"Invalid JSON specified for client_certs")
for sig in sigs:
sig = sig.upper()
if (not name or not CLIENT_SHA256_RE.match(sig) or
sig in client_certs):
raise InvalidConfig(
"Invalid client_certs argument")
client_certs[sig] = name
if ns.fcm_creds:
try:
router_conf["fcm"] = {
"version": ns.fcm_version,
"ttl": ns.fcm_ttl,
"dryrun": ns.fcm_dryrun,
"max_data": ns.max_data,
"collapsekey": ns.fcm_collapsekey,
"creds": json.loads(ns.fcm_creds)
}
if not router_conf["fcm"]["creds"]:
raise InvalidConfig(
"Empty credentials for FCM config options"
)
for creds in router_conf["fcm"]["creds"].values():
if "auth" not in creds:
raise InvalidConfig(
"Missing auth for FCM config options"
)
except (ValueError, TypeError):
raise InvalidConfig(
"Invalid JSON specified for FCM config options"
)
if ns.adm_creds:
# Create a common admclient
try:
router_conf["adm"] = json.loads(ns.adm_creds)
except (ValueError, TypeError):
raise InvalidConfig(
"Invalid JSON specified for ADM config options")
ami_id = None
# Not a fan of double negatives, but this makes more
# understandable args
if not ns.no_aws:
ami_id = get_amid() or "Unknown"
allow_table_rotation = not ns.no_table_rotation
return cls(
crypto_key=ns.crypto_key,
datadog_api_key=ns.datadog_api_key,
datadog_app_key=ns.datadog_app_key,
datadog_flush_interval=ns.datadog_flush_interval,
hostname=ns.hostname,
statsd_host=ns.statsd_host,
statsd_port=ns.statsd_port,
router_conf=router_conf,
resolve_hostname=ns.resolve_hostname,
ami_id=ami_id,
client_certs=client_certs,
msg_limit=ns.msg_limit,
connect_timeout=ns.connection_timeout,
memusage_port=ns.memusage_port,
use_cryptography=ns.use_cryptography,
no_sslcontext_cache=ns._no_sslcontext_cache,
router_table=dict(
tablename=ns.router_tablename,
read_throughput=ns.router_read_throughput,
write_throughput=ns.router_write_throughput
),
message_table=dict(
tablename=ns.message_tablename,
read_throughput=ns.message_read_throughput,
write_throughput=ns.message_write_throughput
),
ssl=dict(
key=ns.ssl_key,
cert=ns.ssl_cert,
dh_param=ns.ssl_dh_param
),
sts_max_age=ns.sts_max_age,
allow_table_rotation=allow_table_rotation,
**kwargs
)
[docs] def make_endpoint(self, uaid, chid, key=None):
"""Create an v1 or v2 WebPush endpoint from the identifiers.
Both endpoints use bytes instead of hex to reduce ID length.
v1 is the uaid + chid
v2 is the uaid + chid + sha256(key).bytes
:param uaid: User Agent Identifier
:param chid: Channel or Subscription ID
:param key: Optional Base64 URL-encoded application server key
:returns: Push endpoint
"""
root = self.endpoint_url + '/wpush/'
base = (uaid.replace('-', '').decode("hex") +
chid.replace('-', '').decode("hex"))
if key is None:
return root + 'v1/' + self.fernet.encrypt(base).strip('=')
raw_key = base64url_decode(key.encode('utf8'))
ep = self.fernet.encrypt(base + sha256(raw_key).digest()).strip('=')
return root + 'v2/' + ep
[docs] def parse_endpoint(self, metrics, token, version="v1", ckey_header=None,
auth_header=None):
"""Parse an endpoint into component elements of UAID, CHID and optional
key hash if v2
:param token: The obscured subscription data.
:param version: This is the API version of the token.
:param ckey_header: the Crypto-Key header bearing the public key
(from Crypto-Key: p256ecdsa=)
:param auth_header: The Authorization header bearing the VAPID info
:raises ValueError: In the case of a malformed endpoint.
:returns: a dict containing (uaid=UAID, chid=CHID, public_key=KEY)
"""
token = self.fernet.decrypt(repad(token).encode('utf8'))
public_key = None
if ckey_header:
try:
crypto_key = CryptoKey(ckey_header)
except CryptoKeyException:
raise InvalidTokenException("Invalid key data")
public_key = crypto_key.get_label('p256ecdsa')
if auth_header:
vapid_auth = parse_auth_header(auth_header)
if not vapid_auth:
raise VapidAuthException("Invalid Auth token")
metrics.increment("notification.auth",
tags="vapid:{version},scheme:{scheme}".format(
**vapid_auth
).split(","))
# pull the public key from the VAPID auth header if needed
try:
if vapid_auth['version'] != 1:
public_key = vapid_auth['k']
except KeyError:
raise VapidAuthException("Missing Public Key")
if version == 'v1' and len(token) != 32:
raise InvalidTokenException("Corrupted push token")
if version == 'v2':
if not auth_header:
raise VapidAuthException("Missing Authorization Header")
if len(token) != 64:
raise InvalidTokenException("Corrupted push token")
if not public_key:
raise VapidAuthException("Invalid key data")
try:
decoded_key = base64url_decode(public_key)
except TypeError:
raise VapidAuthException("Invalid key data")
if not constant_time.bytes_eq(sha256(decoded_key).digest(),
token[32:]):
raise VapidAuthException("Key mismatch")
return dict(uaid=token[:16].encode('hex'),
chid=token[16:32].encode('hex'),
version=version,
public_key=public_key)