summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorBrian May <brian@linuxpenguins.xyz>2015-12-06 11:02:31 +1100
committerBrian May <brian@linuxpenguins.xyz>2015-12-06 11:02:31 +1100
commitbd97506f7d6cf52365e0d93311c95c94decc81c7 (patch)
tree1e75c70035fc60e6d093ae08e00160755662019d
parent53c07f7d90696712aeffde556eddb50cd163b172 (diff)
Fixup firewall tests.
-rw-r--r--sshuttle/tests/test_firewall.py65
1 files changed, 30 insertions, 35 deletions
diff --git a/sshuttle/tests/test_firewall.py b/sshuttle/tests/test_firewall.py
index 540a5bf..10734d9 100644
--- a/sshuttle/tests/test_firewall.py
+++ b/sshuttle/tests/test_firewall.py
@@ -1,9 +1,5 @@
from mock import Mock, patch, call
import io
-import os
-import os.path
-import shutil
-import filecmp
import sshuttle.firewall
@@ -19,27 +15,27 @@ NSLIST
10,2404:6800:4004:80c::33
PORTS 1024,1025,1026,1027
GO 1
+HOST 1.2.3.3,existing
""")
stdout = Mock()
return stdin, stdout
-@patch('sshuttle.firewall.HOSTSFILE', new='tmp/hosts')
-@patch('sshuttle.firewall.hostmap', new={
- 'myhost': '1.2.3.4',
- 'myotherhost': '1.2.3.5',
-})
-def test_rewrite_etc_hosts():
- if not os.path.isdir("tmp"):
- os.mkdir("tmp")
+def test_rewrite_etc_hosts(tmpdir):
+ orig_hosts = tmpdir.join("hosts.orig")
+ orig_hosts.write("1.2.3.3 existing\n")
- with open("tmp/hosts.orig", "w") as f:
- f.write("1.2.3.3 existing\n")
+ new_hosts = tmpdir.join("hosts")
+ orig_hosts.copy(new_hosts)
- shutil.copyfile("tmp/hosts.orig", "tmp/hosts")
+ hostmap = {
+ 'myhost': '1.2.3.4',
+ 'myotherhost': '1.2.3.5',
+ }
+ with patch('sshuttle.firewall.HOSTSFILE', new=str(new_hosts)):
+ sshuttle.firewall.rewrite_etc_hosts(hostmap, 10)
- sshuttle.firewall.rewrite_etc_hosts(10)
- with open("tmp/hosts") as f:
+ with new_hosts.open() as f:
line = f.readline()
s = line.split()
assert s == ['1.2.3.3', 'existing']
@@ -57,39 +53,37 @@ def test_rewrite_etc_hosts():
line = f.readline()
assert line == ""
- sshuttle.firewall.restore_etc_hosts(10)
- assert filecmp.cmp("tmp/hosts.orig", "tmp/hosts", shallow=False) is True
+ with patch('sshuttle.firewall.HOSTSFILE', new=str(new_hosts)):
+ sshuttle.firewall.restore_etc_hosts(10)
+ assert orig_hosts.computehash() == new_hosts.computehash()
-@patch('sshuttle.firewall.HOSTSFILE', new='tmp/hosts')
+@patch('sshuttle.firewall.rewrite_etc_hosts')
@patch('sshuttle.firewall.setup_daemon')
@patch('sshuttle.firewall.get_method')
-def test_main(mock_get_method, mock_setup_daemon):
+def test_main(mock_get_method, mock_setup_daemon, mock_rewrite_etc_hosts):
stdin, stdout = setup_daemon()
mock_setup_daemon.return_value = stdin, stdout
- if not os.path.isdir("tmp"):
- os.mkdir("tmp")
+ mock_get_method("not_auto").name = "test"
+ mock_get_method.reset_mock()
- sshuttle.firewall.main("test", False)
+ sshuttle.firewall.main("not_auto", False)
- with open("tmp/hosts") as f:
- line = f.readline()
- s = line.split()
- assert s == ['1.2.3.3', 'existing']
-
- line = f.readline()
- assert line == ""
+ assert mock_rewrite_etc_hosts.mock_calls == [
+ call({'1.2.3.3': 'existing'}, 1024),
+ call({}, 1024),
+ ]
- stdout.mock_calls == [
+ assert stdout.mock_calls == [
call.write('READY test\n'),
call.flush(),
call.write('STARTED\n'),
call.flush()
]
- mock_setup_daemon.mock_calls == [call()]
- mock_get_method.mock_calls == [
- call('test'),
+ assert mock_setup_daemon.mock_calls == [call()]
+ assert mock_get_method.mock_calls == [
+ call('not_auto'),
call().setup_firewall(
1024, 1026,
[(10, u'2404:6800:4004:80c::33')],
@@ -104,6 +98,7 @@ def test_main(mock_get_method, mock_setup_daemon):
[(2, 24, False, u'1.2.3.0'), (2, 32, True, u'1.2.3.66')],
True),
call().setup_firewall()(),
+ call().setup_firewall()(),
call().setup_firewall(1024, 0, [], 10, [], True),
call().setup_firewall(1025, 0, [], 2, [], True),
]