'''
This module provides the capability to from an abstract Paypal name,
such as "paymentserv", or "occ-conf" to an open connection.
The main entry point is the ConnectionManager.get_connection().
This function will promptly either:
1- raise an Exception which is a subclass of socket.error
2- return a socket
ConnectionManagers provide the following services:
1- name resolution ("paymentserv" to actual ip/port from topos)
2- transient markdown (keeping track of connection failures)
3- socket throttling (keeping track of total open sockets)
4- timeouts (connection and read timeouts from opscfg)
5- protecteds
In addition, by routing all connections through ConnectionManager,
future refactorings/modifications will be easier. For example,
fallbacks or IP multi-plexing.
'''
import time
import datetime
import socket
import random
import weakref
import collections
import gevent.socket
import gevent.ssl
import gevent.resolver_thread
import gevent
import async
import context
import socket_pool
from crypto import SSLContext
import ll
ml = ll.LLogger()
Address = collections.namedtuple('Address', 'ip port')
KNOWN_KEYS = ("connect_timeout_ms", "response_timeout_ms", "max_connect_retry",
"transient_markdown_enabled", "markdown")
ConnectInfo = collections.namedtuple("ConnectInfo", KNOWN_KEYS)
DEFAULT_CONNECT_INFO = ConnectInfo(5000, 30000, 1, False, False)
[docs]class ConnectionManager(object):
def __init__(self,
address_groups=None,
address_aliases=None,
ssl_context=None):
self.sockpools = weakref.WeakKeyDictionary() # one socket pool per ssl
# self.sockpools = {weakref(ssl_ctx): {socket_type: [list of sockets]}}
self.address_groups = address_groups
self.address_aliases = address_aliases
self.ssl_context = ssl_context
self.server_models = ServerModelDirectory()
# map of user-level socket objects to MonitoredSocket instances
self.user_socket_map = weakref.WeakKeyDictionary()
# we need to use gevent.spawn instead of async.spawn because
# at the time the connection manager is constructed, the support
# context is not yet fully initialized
self.culler = gevent.spawn(self.cull_loop)
[docs] def get_connection(
self, name_or_addr, ssl=False, sock_type=None, read_timeout=None):
'''
name_or_addr - the logical name to connect to, e.g. "db-r"
ssl - if set to True, wrap socket with context.protected;
if set to an SSL context, wrap socket with that
sock_type - a type to wrap the socket in; the intention here is for protocols
that want to run asynchronous keep-alives, or higher level handshaking
(strictly speaking, this is just a callable which accepts socket and
returns the thing that should be pooled, but for must uses it will
probably be a class)
'''
ctx = context.get_context()
address_aliases = self.address_aliases or ctx.address_aliases
#ops_config = self.ops_config or ctx.ops_config
#### POTENTIAL ISSUE: OPS CONFIG IS MORE SPECIFIC THAN ADDRESS (owch)
if isinstance(gevent.get_hub().resolver, gevent.resolver_thread.Resolver):
gevent.get_hub().resolver = _Resolver() # avoid pointless thread dispatches
if name_or_addr in address_aliases:
name_or_addr = address_aliases[name_or_addr]
if isinstance(name_or_addr, basestring): # string means a name
name = name_or_addr
address_list = self.get_all_addrs(name)
else:
address_list = [name_or_addr]
# default to a string-ification of ip for the name
name = address_list[0][0].replace('.', '-')
#if name:
# sock_config = ops_config.get_endpoint_config(name)
#else:
# sock_config = ops_config.get_endpoint_config()
sock_config = DEFAULT_CONNECT_INFO
# ensure all DNS resolution is completed; past this point
# everything is in terms of ips
def get_gai(e):
name = e[0].replace(".","-")
with ctx.log.info('DNS', name) as _log:
gai = gevent.socket.getaddrinfo(*e, family=gevent.socket.AF_INET)[0][4]
context.get_context().name_cache[e] = (time.time(), gai)
return gai
def cache_gai(e):
if context.get_context().name_cache.has_key(e):
age, value = context.get_context().name_cache[e]
if time.time() - age > 600:
async.spawn(get_gai, e)
return value
else:
return get_gai(e)
with ctx.log.get_logger('DNS.CACHE').info(name) as _log:
_log['len'] = len(address_list)
address_list = [cache_gai(e) for e in address_list]
with ctx.log.get_logger('COMPACT').info(name):
self._compact(address_list, name)
errors = []
for address in address_list:
try:
log_name = '%s:%s' % (name, address[0])
with ctx.log.get_logger('CONNECT').info(log_name) as _log:
s = self._connect_to_address(
name, ssl, sock_config, address, sock_type, read_timeout)
if hasattr(s, 'getsockname'):
_log["lport"] = s.getsockname()[1]
elif hasattr(s, '_sock'):
_log["lport"] = s._sock.getsockname()[1]
return s
except socket.error as err:
if len(address_list) == 1:
raise
ml.ld("Connection err {0!r}, {1}, {2!r}", address, name, err)
errors.append((address, err))
raise MultiConnectFailure(errors)
[docs] def get_all_addrs(self, name):
'''
returns the all addresses which the logical name would resolve to,
or raises NameNotFound if there is no known address for the given name
'''
ctx = context.get_context()
address_groups = self.address_groups or ctx.address_groups
try:
address_list = list(address_groups[name])
except KeyError:
err_str = "no address found for name {0}".format(name)
if ctx.stage_ip is None:
err_str += " (no stage communication configured; did you forget?)"
raise NameNotFound(err_str)
return address_list
[docs] def get_addr(self, name):
'''
returns the first address which the logical name would resolve to,
equivalent to get_all_addrs(name)[0]
'''
return self.get_all_addrs(name)[0]
def _connect_to_address(
self, name, ssl, sock_config, address, sock_type, read_timeout):
'''
internal helper function that does all the complex bits of establishing
a connection, keeping statistics on connections, handling markdowns
'''
ctx = context.get_context()
if address not in self.server_models:
self.server_models[address] = ServerModel(address)
server_model = self.server_models[address]
if ssl:
if ssl is True:
ssl_context = self.ssl_context or ctx.ssl_context
if ssl_context is None:
raise EnvironmentError("Unable to make protected connection to " +
repr(name or "unknown") + " at " + repr(address) +
" with no SSLContext loaded.")
elif isinstance(ssl, SSLContext):
protected = ssl
elif ssl == PLAIN_SSL:
protected = PLAIN_SSL_PROTECTED
else:
protected = NULL_PROTECTED # something falsey and weak-refable
if protected not in self.sockpools:
self.sockpools[protected] = {}
if sock_type not in self.sockpools[protected]:
idle_timeout = getattr(sock_type, "idle_timeout", 0.25)
self.sockpools[protected][sock_type] = socket_pool.SocketPool(timeout=idle_timeout)
sock = self.sockpools[protected][sock_type].acquire(address)
msock = None
new_sock = False
if not sock:
if sock_config.transient_markdown_enabled:
last_error = server_model.last_error
if last_error and time.time() - last_error < TRANSIENT_MARKDOWN_DURATION:
raise MarkedDownError()
failed = 0
sock_state = None
# is the connection within the data-center?
# use tighter timeouts if so; using the presence of a
# protected connection as a rough heuristic for now
internal = (ssl and ssl != PLAIN_SSL) or 'mayfly' in name
new_sock = False
while True:
try:
ml.ld("CONNECTING...")
sock_state = ctx.markov_stats['socket.state.' + str(address)].make_transitor('connecting')
log_name = str(address[0]) + ":" + str(address[1])
with ctx.log.get_logger('CONNECT.TCP').info(log_name) as _log:
timeout = sock_config.connect_timeout_ms / 1000.0
if internal: # connect timeout of 50ms inside the data center
timeout = min(timeout, ctx.datacenter_connect_timeout)
sock = gevent.socket.create_connection(address, timeout)
_log['timeout'] = timeout
sock_state.transition('connected')
new_sock = True
ml.ld("CONNECTED local port {0!r}/FD {1}", sock.getsockname(), sock.fileno())
if ssl: # TODO: how should SSL failures interact with markdown & connect count?
sock_state.transition('ssl_handshaking')
with ctx.log.get_logger('CONNECT.SSL').info(log_name) as _log:
if ssl == PLAIN_SSL:
sock = gevent.ssl.wrap_socket(sock)
else:
sock = async.wrap_socket_context(sock, protected.ssl_client_context)
sock_state.transition('ssl_established')
break
except socket.error as err:
if sock_state:
sock_state.transition('closed_error')
if failed >= sock_config.max_connect_retry:
server_model.last_error = time.time()
if sock_config.transient_markdown_enabled:
ctx = context.get_context()
ctx.intervals['net.markdowns.' + str(name) + '.' +
str(address[0]) + ':' + str(address[1])].tick()
ctx.intervals['net.markdowns'].tick()
ctx.log.get_logger('error').critical('TMARKDOWN').error(name=str(name),
addr=str(address))
# was event: ('ERROR', 'TMARKDOWN', '2', 'name=' + str(name) + '&addr=' + str(address))
ml.ld("Connection err {0!r}, {1}, {2!r}", address, name, err)
raise
failed += 1
msock = MonitoredSocket(sock, server_model.active_connections, protected,
name, sock_type, sock_state)
server_model.sock_in_use(msock)
if sock_type:
if getattr(sock_type, "wants_protected", False):
sock = sock_type(msock, protected)
else:
sock = sock_type(msock)
else:
sock = msock
if read_timeout is None:
sock.settimeout(sock_config.response_timeout_ms / 1000.0)
else:
sock.settimeout(read_timeout)
if msock and sock is not msock:
# if sock == msock, collection will not work
self.user_socket_map[sock] = weakref.proxy(msock)
self.user_socket_map.get(sock, sock).state.transition('in_use')
sock.new_sock = new_sock
return sock
[docs] def release_connection(self, sock):
# fetch MonitoredSocket
msock = self.user_socket_map.get(sock, sock)
# check the connection for updating of SSL cert (?)
msock.state.transition('pooled')
if context.get_context().sockpool_enabled:
self.sockpools[msock._protected][msock._type].release(sock)
else:
async.killsock(sock)
[docs] def cull_loop(self):
while 1:
for pool in sum([e.values() for e in self.sockpools.values()], []):
async.sleep(CULL_INTERVAL)
pool.cull()
async.sleep(CULL_INTERVAL)
def _compact(self, address_list, name):
'''
try to compact and make room for a new socket connection to one of
address_list raises OutOfSockets() if unable to make room
'''
ctx = context.get_context()
sock_log = ctx.log.get_logger('NET.SOCKET')
all_pools = sum([e.values() for e in self.sockpools.values()], [])
with ctx.log.get_logger('CULL').info(name) as _log:
_log['len'] = len(all_pools)
for pool in all_pools:
pool.cull()
total_num_in_use = sum([len(model.active_connections)
for model in self.server_models.values()])
if total_num_in_use >= GLOBAL_MAX_CONNECTIONS:
sock_log.critical('GLOBAL_MAX').success('culling sockets',
limit=GLOBAL_MAX_CONNECTIONS,
in_use=total_num_in_use)
# try to cull sockets to make room
made_room = False
for pool in all_pools:
if pool.total_sockets:
made_room = True
gevent.joinall(pool.reduce_size(pool.total_sockets / 2))
if not made_room:
ctx.intervals['net.out_of_sockets'].tick()
raise OutOfSockets("maximum global socket limit {0} hit: {1}".format(
GLOBAL_MAX_CONNECTIONS, total_num_in_use))
num_in_use = sum([len(self.server_models[address].active_connections)
for address in address_list])
if num_in_use >= MAX_CONNECTIONS:
sock_log.critical('ADDR_MAX').success('culling {addr} sockets',
limit=GLOBAL_MAX_CONNECTIONS,
in_use=total_num_in_use,
addr=repr(address_list))
# try to cull sockets
made_room = False
for pool in all_pools:
for address in address_list:
num_pooled = pool.socks_pooled_for_addr(address)
if num_pooled:
gevent.joinall(pool.reduce_addr_size(address, num_pooled / 2))
made_room = True
if not made_room:
ctx.intervals['net.out_of_sockets'].tick()
ctx.intervals['net.out_of_sockets.' + str(name)].tick()
raise OutOfSockets("maximum sockets for {0} already in use: {1}".format(
name, num_in_use))
return
CULL_INTERVAL = 1.0
# something falsey, and weak-ref-able
NULL_PROTECTED = type("NullProtected", (object,), {'__nonzero__': lambda self: False})()
# a marker for doing plain ssl with no protected
PLAIN_SSL = "PLAIN_SSL"
PLAIN_SSL_PROTECTED = type("PlainSslProtected", (object,), {})()
# TODO: better sources for this?
TRANSIENT_MARKDOWN_DURATION = 10.0 # seconds
try:
import resource
MAX_CONNECTIONS = int(0.8 * resource.getrlimit(resource.RLIMIT_NOFILE)[0])
GLOBAL_MAX_CONNECTIONS = MAX_CONNECTIONS
except:
MAX_CONNECTIONS = 800
GLOBAL_MAX_CONNECTIONS = 800
# At least, move these to context object for now
class _Resolver(gevent.resolver_thread.Resolver):
'''
See gevent.resolver_thread module. This is a way to avoid thread
dispatch for getaddrinfo called on (ip, port) tuples, since that is
such a common case and the thread dispatch seems to occassionally go
off the rails in high-load environments like stage2.
'''
def getaddrinfo(self, *args, **kwargs):
'''
only short-cut for one very specific case which is extremely
common in our code; don\'t worry about short-cutting the thread
dispatch for all possible cases
'''
if len(args) == 2 and isinstance(args[1], (int, long)):
try:
socket.inet_aton(args[0])
except socket.error:
pass
else: # args is of form (ip_string, integer)
return socket.getaddrinfo(*args)
return super(_Resolver, self).getaddrinfo(*args, **kwargs)
[docs]class ServerModelDirectory(dict):
def __missing__(self, key):
self[key] = ServerModel(key)
return self[key]
[docs]class ServerModel(object):
'''
This class represents an estimate of the state of a given "server".
"Server" is defined here by whatever accepts the socket
connections, which in practice may be an entire pool of server
machines/VMS, each of which has multiple worker thread/procs.
For example, estimate how many connections are currently open
(note: only an estimate, since the exact server-side state of the
sockets is unknown)
'''
def __init__(self, address):
self.last_error = 0
self.active_connections = weakref.WeakKeyDictionary()
self.address = address
[docs] def sock_in_use(self, sock):
self.active_connections[sock] = time.time()
def __repr__(self):
if self.last_error:
dt = datetime.datetime.fromtimestamp(int(self.last_error))
last_error = dt.strftime('%Y-%m-%d %H:%M:%S')
else:
last_error = "(None)"
return "<ServerModel {0} last_error={1} nconns={2}>".format(
repr(self.address), last_error, len(self.active_connections))
[docs]class MonitoredSocket(object):
'''
A socket proxy which allows socket lifetime to be monitored.
'''
def __init__(self, sock, registry, protected,
name=None, type=None, state=None):
self._msock = sock
self._registry = registry # TODO: better name for this
self._spawned = time.time()
self._protected = protected
self._type = type
# alias some functions through for improved performance
# (__getattr__ is pretty slow compared to normal attribute access)
self.name = name
self.state = state
[docs] def send(self, data, flags=0):
ret = self._msock.send(data, flags)
context.get_context().store_network_data(
(self.name, self._msock.getpeername()),
self.fileno(), "OUT", data)
return ret
[docs] def sendall(self, data, flags=0):
ret = self._msock.sendall(data, flags)
context.get_context().store_network_data(
(self.name, self._msock.getpeername()),
self.fileno(), "OUT", data)
return ret
[docs] def recv(self, bufsize, flags=0):
data = self._msock.recv(bufsize, flags)
context.get_context().store_network_data(
(self.name, self._msock.getpeername()),
self.fileno(), "IN", data)
return data
[docs] def close(self):
if self in self._registry:
del self._registry[self]
if self.state:
self.state.transition('closed')
return self._msock.close()
[docs] def shutdown(self, how): # not going to bother tracking half-open sockets
if self in self._registry: # (unlikely they will ever be used)
del self._registry[self]
return self._msock.shutdown(how)
def __repr__(self):
return "<MonitoredSocket " + repr(self._msock) + ">"
def __getattr__(self, attr):
return getattr(self._msock, attr)
[docs]class AddressGroup(object):
'''
An address group represents the set of addresses known by a
specific name to a client at runtime. That is, in a specific
environment (stage, live, etc), an address group represents the
set of <ip, port> pairs to try.
An address group consists of tiers. Each tier should be fully
exhausted before moving on to the next; tiers are "fallbacks". A
tier consists of prioritized addresses. Within a tier, the
addresses should be tried in a priority weighted random order.
The simplest way to use an address group is just to iterate over
it, and try each address in the order returned.
tiers: [ [(weight, (ip, port)), (weight, (ip, port)) ... ] ... ]
'''
def __init__(self, tiers):
if not any(tiers):
raise ValueError("no addresses provided for address group")
self.tiers = tiers
[docs] def connect_ordering(self):
plist = []
for tier in self.tiers:
# Kodos says: "if you can think of a simpler way of
# achieving a weighted random ordering, I'd like to hear
# it" (http://en.wikipedia.org/wiki/Kang_and_Kodos)
tlist = [(random.random() * e[0], e[1]) for e in tier]
tlist.sort()
plist.extend([e[1] for e in tlist])
return plist
def __iter__(self):
return iter(self.connect_ordering())
def __repr__(self):
return "<AddressGroup " + repr(self.tiers) + ">"
[docs]class AddressGroupMap(dict):
'''
For dev mode, will lazily pull in additional addresses.
'''
def __missing__(self, key):
ctx = context.get_context()
if ctx.stage_ip and ctx.topos:
newkey = None
for k in (key, key + "_r1", key + "_ca", key + "_r1_ca"):
if k in ctx.topos.apps:
newkey = k
break
if newkey is not None:
# TODO: maybe do r1 / r2 fallback; however, given this
# is stage only that use case is pretty slim
ports = [int(ctx.topos.get_port(newkey))]
val = AddressGroup(([(1, (ctx.stage_ip, p)) for p in ports],))
self.__dict__.setdefault("warnings", {})
self.setdefault("inferred_addresses", []).append((key, val))
self[key] = val
if key != newkey:
self.warnings["inferred_addresses"].append((newkey, val))
self[newkey] = val
return val
self.__dict__.setdefault("errors", {})
self.errors.setdefault("unknown_addresses", set()).add(key)
ctx.intervals["error.address.missing." + repr(key)].tick()
ctx.intervals["error.address.missing"].tick()
raise KeyError("unknown address requested " + repr(key))
_ADDRESS_SUFFIXES = ["_r" + str(i) for i in range(10)]
_ADDRESS_SUFFIXES = ("_ca",) + tuple(["_r" + str(i) for i in range(10)])
[docs]class MarkedDownError(socket.error):
pass
[docs]class OutOfSockets(socket.error):
pass
[docs]class NameNotFound(socket.error):
pass
[docs]class MultiConnectFailure(socket.error):
pass