diff --git a/ssm/agents.py b/ssm/agents.py index 4e71ae0a..f26fcd4c 100644 --- a/ssm/agents.py +++ b/ssm/agents.py @@ -42,7 +42,9 @@ except ImportError: # ImportError is raised when Ssm2 initialised if AMS is requested but lib # not installed. - AmsConnectionException = None + class AmsConnectionException(Exception): + """Placeholder exception if argo_ams_library not used.""" + pass from ssm import set_up_logging, LOG_BREAK from ssm.ssm2 import Ssm2, Ssm2Exception @@ -201,6 +203,7 @@ def run_sender(protocol, brokers, project, token, cp, log): try: verify_server_cert = cp.getboolean('certificates', 'verify_server_cert') except ConfigParser.NoOptionError: + # If option not set, resort to value of verify_server_cert set above. pass except ConfigParser.NoOptionError: log.info('No server certificate supplied. Will not encrypt messages.') diff --git a/ssm/crypto.py b/ssm/crypto.py index e9380396..13badd8b 100644 --- a/ssm/crypto.py +++ b/ssm/crypto.py @@ -41,9 +41,8 @@ class CryptoException(Exception): def _from_file(filename): """Read entire file into string. Convenience function.""" - f = open(filename, 'r') - s = f.read() - f.close() + with open(filename, 'r') as f: + s = f.read() return s diff --git a/ssm/ssm2.py b/ssm/ssm2.py index 8f0a6ed4..ae1ab665 100644 --- a/ssm/ssm2.py +++ b/ssm/ssm2.py @@ -291,8 +291,6 @@ def _handle_msg(self, text): warning = 'Empty text passed to _handle_msg.' log.warning(warning) return None, None, warning -# if not text.startswith('MIME-Version: 1.0'): -# raise Ssm2Exception('Not a valid message.') # encrypted - this could be nicer if 'application/pkcs7-mime' in text or 'application/x-pkcs7-mime' in text: @@ -655,10 +653,9 @@ def startup(self): """Create the pidfile then start the connection.""" if self._pidfile is not None: try: - f = open(self._pidfile, 'w') - f.write(str(os.getpid())) - f.write('\n') - f.close() + with open(self._pidfile, 'w') as f: + f.write(str(os.getpid())) + f.write('\n') except IOError as e: log.warning('Failed to create pidfile %s: %s', self._pidfile, e) diff --git a/test/test_agents.py b/test/test_agents.py index 6837b1d1..7b33f9ee 100644 --- a/test/test_agents.py +++ b/test/test_agents.py @@ -39,9 +39,8 @@ def test_get_good_dns(self): output = ['/C=UK/O=eScience/OU=CLRC/L=RAL/CN=scarfrap.esc.rl.ac.uk', '/C=UK/O=eScience/OU=CLRC/L=RAL/CN=uas-dev.esc.rl.ac.uk', '/C=UK/ST=RAL/L=A City/O=eScene/OU=CC/CN=cld.grid.rl.uk'] - f = open(self.tf_path, 'w') - f.write(dn_text) - f.close() + with open(self.tf_path, 'w') as f: + f.write(dn_text) self.assertEqual(ssm.agents.get_dns(self.tf_path, self.mock_log), output) def test_get_iffy_dns(self): @@ -54,9 +53,8 @@ def test_get_iffy_dns(self): # Another bad DN C=UK/O=eScene/OU=CLRC/L=RAL/CN=apel-dev.esc.rl.ac.uk """) - f = open(self.tf_path, 'w') - f.write(dn_text) - f.close() + with open(self.tf_path, 'w') as f: + f.write(dn_text) ssm.agents.get_dns(self.tf_path, self.mock_log) self.assertEqual(self.mock_log.warning.call_count, 2) diff --git a/test/test_brokers.py b/test/test_brokers.py index a29bbfe3..93dc776a 100644 --- a/test/test_brokers.py +++ b/test/test_brokers.py @@ -34,11 +34,7 @@ def test_parse_stomp_url(self): http_url = 'http://not.a.stomp.url:8080' - try: - brokers.parse_stomp_url(http_url) - self.fail('Parsed a URL which was not STOMP.') - except ValueError: - pass + self.assertRaises(ValueError, brokers.parse_stomp_url, http_url) self.assertRaises(ValueError, brokers.parse_stomp_url, 'stomp://invalid.port.number:abc') @@ -126,6 +122,9 @@ def _mocked_search(*args, **kwargs): 'rgo.grnet.gr_msg.broker.stomp_175215210,Mds-Vo-name=HG-06-EKT' ',Mds-Vo-name=local,o=grid', {'GlueServiceDataValue': ['PROD']} )] + else: + # This will tell mock to use the normal return value + return mock.DEFAULT if __name__ == '__main__': diff --git a/test/test_crypto.py b/test/test_crypto.py index 2f613ca7..db13fdbb 100644 --- a/test/test_crypto.py +++ b/test/test_crypto.py @@ -106,11 +106,12 @@ def test_sign(self): # Indirect testing, using the verify_message() method retrieved_msg, retrieved_dn = verify(signed, TEST_CA_DIR, False) - if not retrieved_dn == TEST_CERT_DN: - self.fail("The DN of the verified message didn't match the cert.") + self.assertEqual( + retrieved_dn, TEST_CERT_DN, + "The DN of the verified message didn't match the cert.") - if not retrieved_msg == MSG: - self.fail("The verified message didn't match the original.") + self.assertEqual(retrieved_msg, MSG, + "The verified message didn't match the original.") def test_verify(self): @@ -147,43 +148,31 @@ def test_verify(self): retrieved_msg, retrieved_dn = verify(signed_msg, TEST_CA_DIR, False) - if not retrieved_dn == TEST_CERT_DN: - self.fail("The DN of the verified message didn't match the cert.") + self.assertEqual( + retrieved_dn, TEST_CERT_DN, + "The DN of the verified message didn't match the cert.") - if not retrieved_msg.strip() == MSG: - self.fail("The verified messge didn't match the original.") + self.assertEqual( + retrieved_msg.strip(), MSG, + "The verified messge didn't match the original.") retrieved_msg2, retrieved_dn2 = verify(signed_msg2, TEST_CA_DIR, False) - if not retrieved_dn2 == TEST_CERT_DN: - print(retrieved_dn2) - print(TEST_CERT_DN) - self.fail("The DN of the verified message didn't match the cert.") + self.assertEqual( + retrieved_dn2, TEST_CERT_DN, + "The DN of the verified message didn't match the cert.") - if not retrieved_msg2.strip() == MSG2: - print(retrieved_msg2) - print(MSG2) - self.fail("The verified messge didn't match the original.") + self.assertEqual( + retrieved_msg2.strip(), MSG2, + "The verified messge didn't match the original.") # Try empty string - try: - verify('', TEST_CA_DIR, False) - except CryptoException: - pass + self.assertRaises(CryptoException, verify, '', TEST_CA_DIR, False) # Try rubbish - try: - verify('Bibbly bobbly', TEST_CA_DIR, False) - except CryptoException: - pass + self.assertRaises(CryptoException, verify, 'Bibbly bobbly', TEST_CA_DIR, False) # Try None arguments - try: - verify('Bibbly bobbly', None, False) - except CryptoException: - pass - try: - verify(None, 'not a path', False) - except CryptoException: - pass + self.assertRaises(CryptoException, verify, 'Bibbly bobbly', None, False) + self.assertRaises(CryptoException, verify, None, 'not a path', False) def test_get_certificate_subject(self): ''' @@ -198,17 +187,9 @@ def test_get_certificate_subject(self): if not dn == TEST_CERT_DN: self.fail("Didn't retrieve correct DN from cert.") - try: - subj = get_certificate_subject('Rubbish') - self.fail('Returned %s as subject from empty string.' % subj) - except CryptoException: - pass + self.assertRaises(CryptoException, get_certificate_subject, 'Rubbish') - try: - subj = get_certificate_subject('') - self.fail('Returned %s as subject from empty string.' % subj) - except CryptoException: - pass + self.assertRaises(CryptoException, get_certificate_subject, '') def test_get_signer_cert(self): ''' @@ -244,10 +225,7 @@ def test_encrypt(self): self.fail("Encrypted message wasn't decrypted successfully.") # invalid cipher - try: - encrypted = encrypt(MSG, TEST_CERT_FILE, 'aes1024') - except CryptoException: - pass + self.assertRaises(CryptoException, encrypt, MSG, TEST_CERT_FILE, 'aes1024') def test_decrypt(self): @@ -287,11 +265,7 @@ def test_verify_cert(self): self.fail('The self-signed certificate should not be verified ' + 'if CRLs are checked.') - try: - if verify_cert(None, TEST_CA_DIR, False): - self.fail('Verified None rather than certificate string.') - except CryptoException: - pass + self.assertRaises(CryptoException, verify_cert, None, TEST_CA_DIR, False) def test_message_tampering(self): """Test that a tampered message is not accepted as valid."""