summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMatt Martz <matt@sivel.net>2017-05-02 10:56:31 -0500
committerMatt Martz <matt@sivel.net>2017-11-23 10:14:35 -0600
commit10b3b09f02bfd9636af4a10b84abcd3c26035949 (patch)
treeac73b7c9706e9d901eeaf624a1d7694497452f6d
parent20e5d12a5caa015653564c31f7845f1493148bed (diff)
Don't override socket.socket for binding, eliminiate globals SOURCE and USER_AGENT
-rwxr-xr-xspeedtest.py298
1 files changed, 237 insertions, 61 deletions
diff --git a/speedtest.py b/speedtest.py
index cb4a374..115b2f8 100755
--- a/speedtest.py
+++ b/speedtest.py
@@ -51,14 +51,10 @@ class FakeShutdownEvent(object):
# Some global variables we use
-USER_AGENT = None
-SOURCE = None
SHUTDOWN_EVENT = FakeShutdownEvent()
SCHEME = 'http'
DEBUG = False
-
-# Used for bound_interface
-SOCKET_SOCKET = socket.socket
+_GLOBAL_DEFAULT_TIMEOUT = object()
# Begin import game to handle Python 2 and Python 3
try:
@@ -79,9 +75,15 @@ except ImportError:
ET = None
try:
- from urllib2 import urlopen, Request, HTTPError, URLError
+ from urllib2 import (urlopen, Request, HTTPError, URLError,
+ AbstractHTTPHandler, ProxyHandler,
+ HTTPDefaultErrorHandler, HTTPRedirectHandler,
+ HTTPErrorProcessor, OpenerDirector)
except ImportError:
- from urllib.request import urlopen, Request, HTTPError, URLError
+ from urllib.request import (urlopen, Request, HTTPError, URLError,
+ AbstractHTTPHandler, ProxyHandler,
+ HTTPDefaultErrorHandler, HTTPRedirectHandler,
+ HTTPErrorProcessor, OpenerDirector)
try:
from httplib import HTTPConnection
@@ -320,6 +322,165 @@ class SpeedtestBestServerFailure(SpeedtestException):
"""Unable to determine best server"""
+def create_connection(address, timeout=_GLOBAL_DEFAULT_TIMEOUT,
+ source_address=None):
+ """Connect to *address* and return the socket object.
+
+ Convenience function. Connect to *address* (a 2-tuple ``(host,
+ port)``) and return the socket object. Passing the optional
+ *timeout* parameter will set the timeout on the socket instance
+ before attempting to connect. If no *timeout* is supplied, the
+ global default timeout setting returned by :func:`getdefaulttimeout`
+ is used. If *source_address* is set it must be a tuple of (host, port)
+ for the socket to bind as a source address before making the connection.
+ An host of '' or port 0 tells the OS to use the default.
+ """
+
+ host, port = address
+ err = None
+ for res in socket.getaddrinfo(host, port, 0, socket.SOCK_STREAM):
+ af, socktype, proto, canonname, sa = res
+ sock = None
+ try:
+ sock = socket.socket(af, socktype, proto)
+ if timeout is not _GLOBAL_DEFAULT_TIMEOUT:
+ sock.settimeout(float(timeout))
+ if source_address:
+ sock.bind(source_address)
+ sock.connect(sa)
+ return sock
+
+ except socket.error:
+ err = get_exception()
+ if sock is not None:
+ sock.close()
+
+ if err is not None:
+ raise err
+ else:
+ raise socket.error("getaddrinfo returns an empty list")
+
+
+class SpeedtestHTTPConnection(HTTPConnection):
+ def __init__(self, *args, **kwargs):
+ source_address = kwargs.pop('source_address', None)
+ context = kwargs.pop('context', None)
+ timeout = kwargs.pop('timeout', 10)
+
+ HTTPConnection.__init__(self, *args, **kwargs)
+
+ self.source_address = source_address
+ self._context = context
+ self.timeout = timeout
+
+ try:
+ self._create_connection = socket.create_connection
+ except AttributeError:
+ self._create_connection = create_connection
+
+ def connect(self):
+ """Connect to the host and port specified in __init__."""
+ self.sock = self._create_connection(
+ (self.host, self.port),
+ self.timeout,
+ self.source_address
+ )
+
+
+if HTTPSConnection:
+ class SpeedtestHTTPSConnection(HTTPSConnection,
+ SpeedtestHTTPConnection):
+ def connect(self):
+ "Connect to a host on a given (SSL) port."
+
+ SpeedtestHTTPConnection.connect(self)
+
+ kwargs = {}
+ if hasattr(ssl, 'SSLContext'):
+ kwargs['server_hostname'] = self.host
+
+ self.sock = self._context.wrap_socket(self.sock, **kwargs)
+
+
+def _build_connection(connection, source_address, timeout, context=None):
+ def inner(host, **kwargs):
+ kwargs.update({
+ 'source_address': source_address,
+ 'timeout': timeout
+ })
+ if context:
+ kwargs['context'] = context
+ return connection(host, **kwargs)
+ return inner
+
+
+class SpeedtestHTTPHandler(AbstractHTTPHandler):
+ def __init__(self, debuglevel=0, source_address=None, timeout=10):
+ AbstractHTTPHandler.__init__(self, debuglevel)
+ self.source_address = source_address
+ self.timeout = timeout
+
+ def http_open(self, req):
+ return self.do_open(
+ _build_connection(
+ SpeedtestHTTPConnection,
+ self.source_address,
+ self.timeout
+ ),
+ req
+ )
+
+ http_request = AbstractHTTPHandler.do_request_
+
+
+class SpeedtestHTTPSHandler(AbstractHTTPHandler):
+ def __init__(self, debuglevel=0, context=None, source_address=None,
+ timeout=10):
+ AbstractHTTPHandler.__init__(self, debuglevel)
+ self._context = context
+ self.source_address = source_address
+ self.timeout = timeout
+
+ def https_open(self, req):
+ return self.do_open(
+ _build_connection(
+ SpeedtestHTTPSConnection,
+ self.source_address,
+ timeout,
+ context=self._context,
+ ),
+ req
+ )
+
+ https_request = AbstractHTTPHandler.do_request_
+
+
+def build_opener(source_address=None, timeout=10):
+ if source_address:
+ source_address_tuple = (source_address, 0)
+ else:
+ source_address_tuple = None
+
+ handlers = [
+ ProxyHandler(),
+ SpeedtestHTTPHandler(source_address=source_address_tuple,
+ timeout=timeout),
+ SpeedtestHTTPSHandler(source_address=source_address_tuple,
+ timeout=timeout),
+ HTTPDefaultErrorHandler(),
+ HTTPRedirectHandler(),
+ HTTPErrorProcessor()
+ ]
+
+ opener = OpenerDirector()
+ opener.addheaders = [('User-agent', build_user_agent())]
+
+ for handler in handlers:
+ opener.add_handler(handler)
+
+ return opener
+
+
class GzipDecodedResponse(GZIP_BASE):
"""A file-like object to decode a response encoded with the gzip
method, as described in RFC 1952.
@@ -357,14 +518,6 @@ def get_exception():
return sys.exc_info()[1]
-def bound_socket(*args, **kwargs):
- """Bind socket to a specified source IP address"""
-
- sock = SOCKET_SOCKET(*args, **kwargs)
- sock.bind((SOURCE, 0))
- return sock
-
-
def distance(origin, destination):
"""Determine distance between 2 sets of [lat,lon] in km"""
@@ -387,10 +540,6 @@ def distance(origin, destination):
def build_user_agent():
"""Build a Mozilla/5.0 compatible User-Agent string"""
- global USER_AGENT
- if USER_AGENT:
- return USER_AGENT
-
ua_tuple = (
'Mozilla/5.0',
'(%s; U; %s; en-us)' % (platform.system(), platform.architecture()[0]),
@@ -398,9 +547,9 @@ def build_user_agent():
'(KHTML, like Gecko)',
'speedtest-cli/%s' % __version__
)
- USER_AGENT = ' '.join(ua_tuple)
- printer(USER_AGENT, debug=True)
- return USER_AGENT
+ user_agent = ' '.join(ua_tuple)
+ printer(user_agent, debug=True)
+ return user_agent
def build_request(url, data=None, headers=None, bump=''):
@@ -410,9 +559,6 @@ def build_request(url, data=None, headers=None, bump=''):
"""
- if not USER_AGENT:
- build_user_agent()
-
if not headers:
headers = {}
@@ -432,7 +578,6 @@ def build_request(url, data=None, headers=None, bump=''):
bump)
headers.update({
- 'User-Agent': USER_AGENT,
'Cache-Control': 'no-cache',
})
@@ -442,14 +587,19 @@ def build_request(url, data=None, headers=None, bump=''):
return Request(final_url, data=data, headers=headers)
-def catch_request(request):
+def catch_request(request, opener=None):
"""Helper function to catch common exceptions encountered when
establishing a connection with a HTTP/HTTPS request
"""
+ if opener:
+ _open = opener.open
+ else:
+ _open = urlopen
+
try:
- uh = urlopen(request)
+ uh = _open(request)
return uh, False
except HTTP_ERRORS:
e = get_exception()
@@ -505,18 +655,22 @@ def do_nothing(*args, **kwargs):
class HTTPDownloader(threading.Thread):
"""Thread class for retrieving a URL"""
- def __init__(self, i, request, start, timeout):
+ def __init__(self, i, request, start, timeout, opener=None):
threading.Thread.__init__(self)
self.request = request
self.result = [0]
self.starttime = start
self.timeout = timeout
self.i = i
+ if opener:
+ self._opener = opener.open
+ else:
+ self._opener = urlopen
def run(self):
try:
if (timeit.default_timer() - self.starttime) <= self.timeout:
- f = urlopen(self.request)
+ f = self._opener(self.request)
while (not SHUTDOWN_EVENT.isSet() and
(timeit.default_timer() - self.starttime) <=
self.timeout):
@@ -574,7 +728,7 @@ class HTTPUploaderData(object):
class HTTPUploader(threading.Thread):
"""Thread class for putting a URL"""
- def __init__(self, i, request, start, size, timeout):
+ def __init__(self, i, request, start, size, timeout, opener=None):
threading.Thread.__init__(self)
self.request = request
self.request.data.start = self.starttime = start
@@ -583,20 +737,25 @@ class HTTPUploader(threading.Thread):
self.timeout = timeout
self.i = i
+ if opener:
+ self._opener = opener.open
+ else:
+ self._opener = urlopen
+
def run(self):
request = self.request
try:
if ((timeit.default_timer() - self.starttime) <= self.timeout and
not SHUTDOWN_EVENT.isSet()):
try:
- f = urlopen(request)
+ f = self._opener(request)
except TypeError:
# PY24 expects a string or buffer
# This also causes issues with Ctrl-C, but we will concede
# for the moment that Ctrl-C on PY24 isn't immediate
request = build_request(self.request.get_full_url(),
data=request.data.read(self.size))
- f = urlopen(request)
+ f = self._opener(request)
f.read(11)
f.close()
self.result = sum(self.request.data.total)
@@ -619,7 +778,7 @@ class SpeedtestResults(object):
to get a share results image link.
"""
- def __init__(self, download=0, upload=0, ping=0, server=None):
+ def __init__(self, download=0, upload=0, ping=0, server=None, opener=None):
self.download = download
self.upload = upload
self.ping = ping
@@ -632,6 +791,11 @@ class SpeedtestResults(object):
self.bytes_received = 0
self.bytes_sent = 0
+ if opener:
+ self._opener = opener
+ else:
+ self._opener = build_opener()
+
def __repr__(self):
return repr(self.dict())
@@ -674,7 +838,7 @@ class SpeedtestResults(object):
request = build_request('://www.speedtest.net/api/api.php',
data='&'.join(api_data).encode(),
headers=headers)
- f, e = catch_request(request)
+ f, e = catch_request(request, opener=_self.opener)
if e:
raise ShareResultsConnectFailure(e)
@@ -738,8 +902,13 @@ class SpeedtestResults(object):
class Speedtest(object):
"""Class for performing standard speedtest.net testing operations"""
- def __init__(self, config=None):
+ def __init__(self, config=None, source_address=None, timeout=10):
self.config = {}
+
+ self._source_address = source_address
+ self._timeout = timeout
+ self._opener = build_opener(source_address, timeout)
+
self.get_config()
if config is not None:
self.config.update(config)
@@ -748,7 +917,7 @@ class Speedtest(object):
self.closest = []
self.best = {}
- self.results = SpeedtestResults()
+ self.results = SpeedtestResults(opener=self._opener)
def get_config(self):
"""Download the speedtest.net configuration and return only the data
@@ -760,7 +929,7 @@ class Speedtest(object):
headers['Accept-Encoding'] = 'gzip'
request = build_request('://www.speedtest.net/speedtest-config.php',
headers=headers)
- uh, e = catch_request(request)
+ uh, e = catch_request(request, opener=self._opener)
if e:
raise ConfigRetrievalError(e)
configxml = []
@@ -877,7 +1046,7 @@ class Speedtest(object):
(url,
self.config['threads']['download']),
headers=headers)
- uh, e = catch_request(request)
+ uh, e = catch_request(request, opener=self._opener)
if e:
errors.append('%s' % e)
raise ServersRetrievalError()
@@ -960,7 +1129,7 @@ class Speedtest(object):
url = server
request = build_request(url)
- uh, e = catch_request(request)
+ uh, e = catch_request(request, opener=self._opener)
if e:
raise SpeedtestMiniConnectFailure('Failed to connect to %s' %
server)
@@ -973,7 +1142,9 @@ class Speedtest(object):
if not extension:
for ext in ['php', 'asp', 'aspx', 'jsp']:
try:
- f = urlopen('%s/speedtest/upload.%s' % (url, ext))
+ f = self._opener.open(
+ '%s/speedtest/upload.%s' % (url, ext)
+ )
except:
pass
else:
@@ -1028,6 +1199,13 @@ class Speedtest(object):
servers = self.get_closest_servers()
servers = self.closest
+ if self._source_address:
+ source_address_tuple = (self._source_address, 0)
+ else:
+ source_address_tuple = None
+
+ user_agent = build_user_agent()
+
results = {}
for server in servers:
cum = []
@@ -1037,10 +1215,16 @@ class Speedtest(object):
for _ in range(0, 3):
try:
if urlparts[0] == 'https':
- h = HTTPSConnection(urlparts[1])
+ h = SpeedtestHTTPSConnection(
+ urlparts[1],
+ source_address=source_address_tuple
+ )
else:
- h = HTTPConnection(urlparts[1])
- headers = {'User-Agent': USER_AGENT}
+ h = SpeedtestHTTPConnection(
+ urlparts[1],
+ source_address=source_address_tuple
+ )
+ headers = {'User-Agent': user_agent}
start = timeit.default_timer()
h.request("GET", urlparts[2], headers=headers)
r = h.getresponse()
@@ -1093,7 +1277,8 @@ class Speedtest(object):
def producer(q, requests, request_count):
for i, request in enumerate(requests):
thread = HTTPDownloader(i, request, start,
- self.config['length']['download'])
+ self.config['length']['download'],
+ opener=self._opener)
thread.start()
q.put(thread, True)
callback(i, request_count, start=True)
@@ -1159,7 +1344,8 @@ class Speedtest(object):
def producer(q, requests, request_count):
for i, request in enumerate(requests[:request_count]):
thread = HTTPUploader(i, request[0], start, request[1],
- self.config['length']['upload'])
+ self.config['length']['upload'],
+ opener=self._opener)
thread.start()
q.put(thread, True)
callback(i, request_count, start=True)
@@ -1338,7 +1524,7 @@ def printer(string, quiet=False, debug=False, **kwargs):
def shell():
"""Run the full speedtest.net test"""
- global SHUTDOWN_EVENT, SOURCE, SCHEME, DEBUG
+ global SHUTDOWN_EVENT, SCHEME, DEBUG
SHUTDOWN_EVENT = threading.Event()
signal.signal(signal.SIGINT, ctrl_c)
@@ -1361,13 +1547,6 @@ def shell():
validate_optional_args(args)
- socket.setdefaulttimeout(args.timeout)
-
- # If specified bind to a specific IP address
- if args.source:
- SOURCE = args.source
- socket.socket = bound_socket
-
if args.secure:
SCHEME = 'https'
@@ -1377,9 +1556,6 @@ def shell():
if debug:
DEBUG = True
- # Pre-cache the user agent string
- build_user_agent()
-
if args.simple or args.csv or args.json:
quiet = True
else:
@@ -1398,15 +1574,15 @@ def shell():
printer('Retrieving speedtest.net configuration...', quiet)
try:
- speedtest = Speedtest()
- except (ConfigRetrievalError, HTTP_ERRORS):
+ speedtest = Speedtest(source_address=args.source, timeout=args.timeout)
+ except (ConfigRetrievalError,) + HTTP_ERRORS:
printer('Cannot retrieve speedtest configuration')
raise SpeedtestCLIError(get_exception())
if args.list:
try:
speedtest.get_servers()
- except (ServersRetrievalError, HTTP_ERRORS):
+ except (ServersRetrievalError,) + HTTP_ERRORS:
print_('Cannot retrieve speedtest server list')
raise SpeedtestCLIError(get_exception())
@@ -1436,7 +1612,7 @@ def shell():
speedtest.get_servers(servers)
except NoMatchedServers:
raise SpeedtestCLIError('No matched servers: %s' % args.server)
- except (ServersRetrievalError, HTTP_ERRORS):
+ except (ServersRetrievalError,) + HTTP_ERRORS:
print_('Cannot retrieve speedtest server list')
raise SpeedtestCLIError(get_exception())
except InvalidServerIDType: