diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index 770b000fca..e77e731fca 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -937,13 +937,6 @@ def tearDown(self): self.cluster.shutdown() -def assert_startswith(s, prefix): - if not s.startswith(prefix): - raise AssertionError( - '{} does not start with {}'.format(repr(s), repr(prefix)) - ) - - class TestCluster(object): __test__ = False diff --git a/tests/integration/cqlengine/__init__.py b/tests/integration/cqlengine/__init__.py index 204dcb1253..7fae437370 100644 --- a/tests/integration/cqlengine/__init__.py +++ b/tests/integration/cqlengine/__init__.py @@ -77,12 +77,7 @@ def wrapped_function(*args, **kwargs): # DeMonkey Patch our code cassandra.cqlengine.connection.execute = original_function # Check to see if we have a pre-existing test case to work from. - if args: - test_case = args[0] - else: - test_case = unittest.TestCase("__init__") - # Check to see if the count is what you expect - test_case.assertEqual(count.get_counter(), expected, msg="Expected number of cassandra.cqlengine.connection.execute calls ({0}) doesn't match actual number invoked ({1})".format(expected, count.get_counter())) + assert count.get_counter() == expected, "Expected number of cassandra.cqlengine.connection.execute calls ({0}) doesn't match actual number invoked ({1})".format(expected, count.get_counter()) return to_return # Name of the wrapped function must match the original or unittest will error out. 
wrapped_function.__name__ = fn.__name__ @@ -94,5 +89,3 @@ def wrapped_function(*args, **kwargs): return wrapped_function return innerCounter - - diff --git a/tests/integration/cqlengine/base.py b/tests/integration/cqlengine/base.py index e2c02c82a3..c65554b974 100644 --- a/tests/integration/cqlengine/base.py +++ b/tests/integration/cqlengine/base.py @@ -40,15 +40,3 @@ class BaseCassEngTestCase(unittest.TestCase): def setUp(self): self.session = get_session() - - def assertHasAttr(self, obj, attr): - self.assertTrue(hasattr(obj, attr), - "{0} doesn't have attribute: {1}".format(obj, attr)) - - def assertNotHasAttr(self, obj, attr): - self.assertFalse(hasattr(obj, attr), - "{0} shouldn't have the attribute: {1}".format(obj, attr)) - - if sys.version_info > (3, 0): - def assertItemsEqual(self, first, second, msg=None): - return self.assertCountEqual(first, second, msg) diff --git a/tests/integration/cqlengine/columns/test_container_columns.py b/tests/integration/cqlengine/columns/test_container_columns.py index 4c21ff55d8..6fb2754877 100644 --- a/tests/integration/cqlengine/columns/test_container_columns.py +++ b/tests/integration/cqlengine/columns/test_container_columns.py @@ -30,6 +30,7 @@ from tests.integration.cqlengine import is_prepend_reversed from tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration import greaterthancass20, CASSANDRA_VERSION +import pytest log = logging.getLogger(__name__) @@ -72,7 +73,8 @@ def tearDownClass(cls): drop_table(TestSetModel) def test_add_none_fails(self): - self.assertRaises(ValidationError, TestSetModel.create, **{'int_set': set([None])}) + with pytest.raises(ValidationError): + TestSetModel.create(int_set=set([None])) def test_empty_set_initial(self): """ @@ -91,7 +93,7 @@ def test_deleting_last_item_should_succeed(self): m.save() m = TestSetModel.get(partition=m.partition) - self.assertTrue(5 not in m.int_set) + assert 5 not in m.int_set def test_blind_deleting_last_item_should_succeed(self): m = 
TestSetModel.create() @@ -101,7 +103,7 @@ def test_blind_deleting_last_item_should_succeed(self): TestSetModel.objects(partition=m.partition).update(int_set=set()) m = TestSetModel.get(partition=m.partition) - self.assertTrue(5 not in m.int_set) + assert 5 not in m.int_set def test_empty_set_retrieval(self): m = TestSetModel.create() @@ -113,20 +115,21 @@ def test_io_success(self): m1 = TestSetModel.create(int_set=set((1, 2)), text_set=set(('kai', 'andreas'))) m2 = TestSetModel.get(partition=m1.partition) - self.assertIsInstance(m2.int_set, set) - self.assertIsInstance(m2.text_set, set) + assert isinstance(m2.int_set, set) + assert isinstance(m2.text_set, set) - self.assertIn(1, m2.int_set) - self.assertIn(2, m2.int_set) + assert 1 in m2.int_set + assert 2 in m2.int_set - self.assertIn('kai', m2.text_set) - self.assertIn('andreas', m2.text_set) + assert 'kai' in m2.text_set + assert 'andreas' in m2.text_set def test_type_validation(self): """ Tests that attempting to use the wrong types will raise an exception """ - self.assertRaises(ValidationError, TestSetModel.create, **{'int_set': set(('string', True)), 'text_set': set((1, 3.0))}) + with pytest.raises(ValidationError): + TestSetModel.create(int_set=set(('string', True)), text_set=set((1, 3.0))) def test_element_count_validation(self): """ @@ -142,8 +145,9 @@ def test_element_count_validation(self): del tb except OperationTimedOut: #This will happen if the host is remote - self.assertFalse(CASSANDRA_IP.startswith("127.0.0.")) - self.assertRaises(ValidationError, TestSetModel.create, **{'text_set': set(str(uuid4()) for i in range(65536))}) + assert not CASSANDRA_IP.startswith("127.0.0.") + with pytest.raises(ValidationError): + TestSetModel.create(text_set=set(str(uuid4()) for i in range(65536))) def test_partial_updates(self): """ Tests that partial udpates work as expected """ @@ -151,12 +155,12 @@ def test_partial_updates(self): m1.int_set.add(5) m1.int_set.remove(1) - self.assertEqual(m1.int_set, set((2, 3, 
4, 5))) + assert m1.int_set == set((2, 3, 4, 5)) m1.save() m2 = TestSetModel.get(partition=m1.partition) - self.assertEqual(m2.int_set, set((2, 3, 4, 5))) + assert m2.int_set == set((2, 3, 4, 5)) def test_instantiation_with_column_class(self): """ @@ -164,23 +168,23 @@ def test_instantiation_with_column_class(self): and that the class is instantiated in the constructor """ column = columns.Set(columns.Text) - self.assertIsInstance(column.value_col, columns.Text) + assert isinstance(column.value_col, columns.Text) def test_instantiation_with_column_instance(self): """ Tests that columns instantiated with a column instance work properly """ column = columns.Set(columns.Text(min_length=100)) - self.assertIsInstance(column.value_col, columns.Text) + assert isinstance(column.value_col, columns.Text) def test_to_python(self): """ Tests that to_python of value column is called """ column = columns.Set(JsonTestColumn) val = set((1, 2, 3)) db_val = column.to_database(val) - self.assertEqual(db_val, set(json.dumps(v) for v in val)) + assert db_val == set(json.dumps(v) for v in val) py_val = column.to_python(db_val) - self.assertEqual(py_val, val) + assert py_val == val def test_default_empty_container_saving(self): """ tests that the default empty container is not saved if it hasn't been updated """ @@ -191,7 +195,7 @@ def test_default_empty_container_saving(self): TestSetModel.create(partition=pkey) m = TestSetModel.get(partition=pkey) - self.assertEqual(m.int_set, set((3, 4))) + assert m.int_set == set((3, 4)) class TestListModel(Model): @@ -227,23 +231,24 @@ def test_io_success(self): m1 = TestListModel.create(int_list=[1, 2], text_list=['kai', 'andreas']) m2 = TestListModel.get(partition=m1.partition) - self.assertIsInstance(m2.int_list, list) - self.assertIsInstance(m2.text_list, list) + assert isinstance(m2.int_list, list) + assert isinstance(m2.text_list, list) - self.assertEqual(len(m2.int_list), 2) - self.assertEqual(len(m2.text_list), 2) + assert len(m2.int_list) 
== 2 + assert len(m2.text_list) == 2 - self.assertEqual(m2.int_list[0], 1) - self.assertEqual(m2.int_list[1], 2) + assert m2.int_list[0] == 1 + assert m2.int_list[1] == 2 - self.assertEqual(m2.text_list[0], 'kai') - self.assertEqual(m2.text_list[1], 'andreas') + assert m2.text_list[0] == 'kai' + assert m2.text_list[1] == 'andreas' def test_type_validation(self): """ Tests that attempting to use the wrong types will raise an exception """ - self.assertRaises(ValidationError, TestListModel.create, **{'int_list': ['string', True], 'text_list': [1, 3.0]}) + with pytest.raises(ValidationError): + TestListModel.create(int_list=['string', True], text_list=[1, 3.0]) def test_element_count_validation(self): """ @@ -257,7 +262,8 @@ def test_element_count_validation(self): ex_type, ex, tb = sys.exc_info() log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb - self.assertRaises(ValidationError, TestListModel.create, **{'text_list': [str(uuid4()) for _ in range(65536)]}) + with pytest.raises(ValidationError): + TestListModel.create(text_list=[str(uuid4()) for _ in range(65536)]) def test_partial_updates(self): """ Tests that partial udpates work as expected """ @@ -275,7 +281,7 @@ def test_partial_updates(self): expected = full m2 = TestListModel.get(partition=m1.partition) - self.assertEqual(list(m2.int_list), expected) + assert list(m2.int_list) == expected def test_instantiation_with_column_class(self): """ @@ -283,23 +289,23 @@ def test_instantiation_with_column_class(self): and that the class is instantiated in the constructor """ column = columns.List(columns.Text) - self.assertIsInstance(column.value_col, columns.Text) + assert isinstance(column.value_col, columns.Text) def test_instantiation_with_column_instance(self): """ Tests that columns instantiated with a column instance work properly """ column = columns.List(columns.Text(min_length=100)) - self.assertIsInstance(column.value_col, columns.Text) + assert 
isinstance(column.value_col, columns.Text) def test_to_python(self): """ Tests that to_python of value column is called """ column = columns.List(JsonTestColumn) val = [1, 2, 3] db_val = column.to_database(val) - self.assertEqual(db_val, [json.dumps(v) for v in val]) + assert db_val == [json.dumps(v) for v in val] py_val = column.to_python(db_val) - self.assertEqual(py_val, val) + assert py_val == val def test_default_empty_container_saving(self): """ tests that the default empty container is not saved if it hasn't been updated """ @@ -310,7 +316,7 @@ def test_default_empty_container_saving(self): TestListModel.create(partition=pkey) m = TestListModel.get(partition=pkey) - self.assertEqual(m.int_list, [1, 2, 3, 4]) + assert m.int_list == [1, 2, 3, 4] def test_remove_entry_works(self): pkey = uuid4() @@ -318,7 +324,7 @@ def test_remove_entry_works(self): tmp.int_list.pop() tmp.update() tmp = TestListModel.get(partition=pkey) - self.assertEqual(tmp.int_list, [1]) + assert tmp.int_list == [1] def test_update_from_non_empty_to_empty(self): pkey = uuid4() @@ -327,11 +333,12 @@ def test_update_from_non_empty_to_empty(self): tmp.update() tmp = TestListModel.get(partition=pkey) - self.assertEqual(tmp.int_list, []) + assert tmp.int_list == [] def test_insert_none(self): pkey = uuid4() - self.assertRaises(ValidationError, TestListModel.create, **{'partition': pkey, 'int_list': [None]}) + with pytest.raises(ValidationError): + TestListModel.create(partition=pkey, int_list=[None]) def test_blind_list_updates_from_none(self): """ Tests that updates from None work as expected """ @@ -341,12 +348,12 @@ def test_blind_list_updates_from_none(self): m.save() m2 = TestListModel.get(partition=m.partition) - self.assertEqual(m2.int_list, expected) + assert m2.int_list == expected TestListModel.objects(partition=m.partition).update(int_list=[]) m3 = TestListModel.get(partition=m.partition) - self.assertEqual(m3.int_list, []) + assert m3.int_list == [] class TestMapModel(Model): @@ 
-373,7 +380,8 @@ def test_empty_default(self): tmp.int_map['blah'] = 1 def test_add_none_as_map_key(self): - self.assertRaises(ValidationError, TestMapModel.create, **{'int_map': {None: uuid4()}}) + with pytest.raises(ValidationError): + TestMapModel.create(int_map={None: uuid4()}) def test_empty_retrieve(self): tmp = TestMapModel.create() @@ -388,7 +396,7 @@ def test_remove_last_entry_works(self): tmp.save() tmp = TestMapModel.get(partition=tmp.partition) - self.assertTrue("blah" not in tmp.int_map) + assert "blah" not in tmp.int_map def test_io_success(self): """ Tests that a basic usage works as expected """ @@ -400,22 +408,23 @@ def test_io_success(self): text_map={'now': now, 'then': then}) m2 = TestMapModel.get(partition=m1.partition) - self.assertTrue(isinstance(m2.int_map, dict)) - self.assertTrue(isinstance(m2.text_map, dict)) + assert isinstance(m2.int_map, dict) + assert isinstance(m2.text_map, dict) - self.assertTrue(1 in m2.int_map) - self.assertTrue(2 in m2.int_map) - self.assertEqual(m2.int_map[1], k1) - self.assertEqual(m2.int_map[2], k2) + assert 1 in m2.int_map + assert 2 in m2.int_map + assert m2.int_map[1] == k1 + assert m2.int_map[2] == k2 - self.assertAlmostEqual(get_total_seconds(now - m2.text_map['now']), 0, 2) - self.assertAlmostEqual(get_total_seconds(then - m2.text_map['then']), 0, 2) + assert get_total_seconds(now - m2.text_map['now']) == pytest.approx(0, abs=1e-2) + assert get_total_seconds(then - m2.text_map['then']) == pytest.approx(0, abs=1e-2) def test_type_validation(self): """ Tests that attempting to use the wrong types will raise an exception """ - self.assertRaises(ValidationError, TestMapModel.create, **{'int_map': {'key': 2, uuid4(): 'val'}, 'text_map': {2: 5}}) + with pytest.raises(ValidationError): + TestMapModel.create(int_map={'key': 2, uuid4(): 'val'}, text_map={2: 5}) def test_element_count_validation(self): """ @@ -429,7 +438,8 @@ def test_element_count_validation(self): ex_type, ex, tb = sys.exc_info() 
log.warning("{0}: {1} Backtrace: {2}".format(ex_type.__name__, ex, traceback.extract_tb(tb))) del tb - self.assertRaises(ValidationError, TestMapModel.create, **{'text_map': dict((str(uuid4()), i) for i in range(65536))}) + with pytest.raises(ValidationError): + TestMapModel.create(text_map=dict((str(uuid4()), i) for i in range(65536))) def test_partial_updates(self): """ Tests that partial udpates work as expected """ @@ -449,7 +459,7 @@ def test_partial_updates(self): m1.save() m2 = TestMapModel.get(partition=m1.partition) - self.assertEqual(m2.text_map, final) + assert m2.text_map == final def test_updates_from_none(self): """ Tests that updates from None work as expected """ @@ -459,12 +469,12 @@ def test_updates_from_none(self): m.save() m2 = TestMapModel.get(partition=m.partition) - self.assertEqual(m2.int_map, expected) + assert m2.int_map == expected m2.int_map = None m2.save() m3 = TestMapModel.get(partition=m.partition) - self.assertNotEqual(m3.int_map, expected) + assert m3.int_map != expected def test_blind_updates_from_none(self): """ Tests that updates from None work as expected """ @@ -474,12 +484,12 @@ def test_blind_updates_from_none(self): m.save() m2 = TestMapModel.get(partition=m.partition) - self.assertEqual(m2.int_map, expected) + assert m2.int_map == expected TestMapModel.objects(partition=m.partition).update(int_map={}) m3 = TestMapModel.get(partition=m.partition) - self.assertNotEqual(m3.int_map, expected) + assert m3.int_map != expected def test_updates_to_none(self): """ Tests that setting the field to None works as expected """ @@ -488,7 +498,7 @@ def test_updates_to_none(self): m.save() m2 = TestMapModel.get(partition=m.partition) - self.assertEqual(m2.int_map, {}) + assert m2.int_map == {} def test_instantiation_with_column_class(self): """ @@ -496,25 +506,25 @@ def test_instantiation_with_column_class(self): and that the class is instantiated in the constructor """ column = columns.Map(columns.Text, columns.Integer) - 
self.assertIsInstance(column.key_col, columns.Text) - self.assertIsInstance(column.value_col, columns.Integer) + assert isinstance(column.key_col, columns.Text) + assert isinstance(column.value_col, columns.Integer) def test_instantiation_with_column_instance(self): """ Tests that columns instantiated with a column instance work properly """ column = columns.Map(columns.Text(min_length=100), columns.Integer()) - self.assertIsInstance(column.key_col, columns.Text) - self.assertIsInstance(column.value_col, columns.Integer) + assert isinstance(column.key_col, columns.Text) + assert isinstance(column.value_col, columns.Integer) def test_to_python(self): """ Tests that to_python of value column is called """ column = columns.Map(JsonTestColumn, JsonTestColumn) val = {1: 2, 3: 4, 5: 6} db_val = column.to_database(val) - self.assertEqual(db_val, dict((json.dumps(k), json.dumps(v)) for k, v in val.items())) + assert db_val == dict((json.dumps(k), json.dumps(v)) for k, v in val.items()) py_val = column.to_python(db_val) - self.assertEqual(py_val, val) + assert py_val == val def test_default_empty_container_saving(self): """ tests that the default empty container is not saved if it hasn't been updated """ @@ -526,7 +536,7 @@ def test_default_empty_container_saving(self): TestMapModel.create(partition=pkey) m = TestMapModel.get(partition=pkey) - self.assertEqual(m.int_map, tmap) + assert m.int_map == tmap class TestCamelMapModel(Model): @@ -617,13 +627,13 @@ def test_io_success(self): m1 = TestTupleModel.create(int_tuple=(1, 2, 3, 5, 6), text_tuple=('kai', 'andreas'), mixed_tuple=('first', 2, 'Third')) m2 = TestTupleModel.get(partition=m1.partition) - self.assertIsInstance(m2.int_tuple, tuple) - self.assertIsInstance(m2.text_tuple, tuple) - self.assertIsInstance(m2.mixed_tuple, tuple) + assert isinstance(m2.int_tuple, tuple) + assert isinstance(m2.text_tuple, tuple) + assert isinstance(m2.mixed_tuple, tuple) - self.assertEqual((1, 2, 3), m2.int_tuple) - 
self.assertEqual(('kai', 'andreas'), m2.text_tuple) - self.assertEqual(('first', 2, 'Third'), m2.mixed_tuple) + assert (1, 2, 3) == m2.int_tuple + assert ('kai', 'andreas') == m2.text_tuple + assert ('first', 2, 'Third') == m2.mixed_tuple def test_type_validation(self): """ @@ -635,9 +645,12 @@ def test_type_validation(self): @test_category object_mapper """ - self.assertRaises(ValidationError, TestTupleModel.create, **{'int_tuple': ('string', True), 'text_tuple': ('test', 'test'), 'mixed_tuple': ('one', 2, 'three')}) - self.assertRaises(ValidationError, TestTupleModel.create, **{'int_tuple': ('string', 'string'), 'text_tuple': (1, 3.0), 'mixed_tuple': ('one', 2, 'three')}) - self.assertRaises(ValidationError, TestTupleModel.create, **{'int_tuple': ('string', 'string'), 'text_tuple': ('test', 'test'), 'mixed_tuple': (1, "two", 3)}) + with pytest.raises(ValidationError): + TestTupleModel.create(int_tuple=('string', True), text_tuple=('test', 'test'), mixed_tuple=('one', 2, 'three')) + with pytest.raises(ValidationError): + TestTupleModel.create(int_tuple=('string', 'string'), text_tuple=(1, 3.0), mixed_tuple=('one', 2, 'three')) + with pytest.raises(ValidationError): + TestTupleModel.create(int_tuple=('string', 'string'), text_tuple=('test', 'test'), mixed_tuple=(1, "two", 3)) def test_instantiation_with_column_class(self): """ @@ -651,10 +664,10 @@ def test_instantiation_with_column_class(self): @test_category object_mapper """ mixed_tuple = columns.Tuple(columns.Text, columns.Integer, columns.Text, required=False) - self.assertIsInstance(mixed_tuple.types[0], columns.Text) - self.assertIsInstance(mixed_tuple.types[1], columns.Integer) - self.assertIsInstance(mixed_tuple.types[2], columns.Text) - self.assertEqual(len(mixed_tuple.types), 3) + assert isinstance(mixed_tuple.types[0], columns.Text) + assert isinstance(mixed_tuple.types[1], columns.Integer) + assert isinstance(mixed_tuple.types[2], columns.Text) + assert len(mixed_tuple.types) == 3 def 
test_default_empty_container_saving(self): """ @@ -673,7 +686,7 @@ def test_default_empty_container_saving(self): TestTupleModel.create(partition=pkey) m = TestTupleModel.get(partition=pkey) - self.assertEqual(m.int_tuple, (1, 2, 3)) + assert m.int_tuple == (1, 2, 3) def test_updates(self): """ @@ -693,7 +706,7 @@ def test_updates(self): m1.save() m2 = TestTupleModel.get(partition=m1.partition) - self.assertEqual(tuple(m2.int_tuple), replacement) + assert tuple(m2.int_tuple) == replacement def test_update_from_non_empty_to_empty(self): """ @@ -711,7 +724,7 @@ def test_update_from_non_empty_to_empty(self): tmp.update() tmp = TestTupleModel.get(partition=pkey) - self.assertEqual(tmp.int_tuple, (None)) + assert tmp.int_tuple == (None) def test_insert_none(self): """ @@ -725,7 +738,7 @@ def test_insert_none(self): """ pkey = uuid4() tmp = TestTupleModel.create(partition=pkey, int_tuple=(None)) - self.assertEqual((None), tmp.int_tuple) + assert (None) == tmp.int_tuple def test_blind_tuple_updates_from_none(self): """ @@ -744,12 +757,12 @@ def test_blind_tuple_updates_from_none(self): m.save() m2 = TestTupleModel.get(partition=m.partition) - self.assertEqual(m2.int_tuple, expected) + assert m2.int_tuple == expected TestTupleModel.objects(partition=m.partition).update(int_tuple=None) m3 = TestTupleModel.get(partition=m.partition) - self.assertEqual(m3.int_tuple, None) + assert m3.int_tuple == None class TestNestedModel(Model): @@ -824,15 +837,15 @@ def test_io_success(self): m1 = TestNestedModel.create(list_list=list_list_master, map_list=map_list_master, set_tuple=set_tuple_master) m2 = TestNestedModel.get(partition=m1.partition) - self.assertIsInstance(m2.list_list, list) - self.assertIsInstance(m2.list_list[0], list) - self.assertIsInstance(m2.map_list, dict) - self.assertIsInstance(m2.map_list.get("key2"), list) + assert isinstance(m2.list_list, list) + assert isinstance(m2.list_list[0], list) + assert isinstance(m2.map_list, dict) + assert 
isinstance(m2.map_list.get("key2"), list) - self.assertEqual(list_list_master, m2.list_list) - self.assertEqual(map_list_master, m2.map_list) - self.assertEqual(set_tuple_master, m2.set_tuple) - self.assertIsInstance(m2.set_tuple.pop(), tuple) + assert list_list_master == m2.list_list + assert map_list_master == m2.map_list + assert set_tuple_master == m2.set_tuple + assert isinstance(m2.set_tuple.pop(), tuple) def test_type_validation(self): """ @@ -853,12 +866,18 @@ def test_type_validation(self): set_tuple_bad_tuple_value = set((("text", "text"), ("text", "text"), ("text", "text"))) set_tuple_not_set = ['This', 'is', 'not', 'a', 'set'] - self.assertRaises(ValidationError, TestNestedModel.create, **{'list_list': list_list_bad_list_context}) - self.assertRaises(ValidationError, TestNestedModel.create, **{'list_list': list_list_no_list}) - self.assertRaises(ValidationError, TestNestedModel.create, **{'map_list': map_list_bad_value}) - self.assertRaises(ValidationError, TestNestedModel.create, **{'map_list': map_list_bad_key}) - self.assertRaises(ValidationError, TestNestedModel.create, **{'set_tuple': set_tuple_bad_tuple_value}) - self.assertRaises(ValidationError, TestNestedModel.create, **{'set_tuple': set_tuple_not_set}) + with pytest.raises(ValidationError): + TestNestedModel.create(list_list=list_list_bad_list_context) + with pytest.raises(ValidationError): + TestNestedModel.create(list_list=list_list_no_list) + with pytest.raises(ValidationError): + TestNestedModel.create(map_list=map_list_bad_value) + with pytest.raises(ValidationError): + TestNestedModel.create(map_list=map_list_bad_key) + with pytest.raises(ValidationError): + TestNestedModel.create(set_tuple=set_tuple_bad_tuple_value) + with pytest.raises(ValidationError): + TestNestedModel.create(set_tuple=set_tuple_not_set) def test_instantiation_with_column_class(self): """ @@ -875,11 +894,11 @@ def test_instantiation_with_column_class(self): map_list = columns.Map(columns.Text, 
columns.List(columns.Text), required=False) set_tuple = columns.Set(columns.Tuple(columns.Integer, columns.Integer), required=False) - self.assertIsInstance(list_list, columns.List) - self.assertIsInstance(list_list.types[0], columns.List) - self.assertIsInstance(map_list.types[0], columns.Text) - self.assertIsInstance(map_list.types[1], columns.List) - self.assertIsInstance(set_tuple.types[0], columns.Tuple) + assert isinstance(list_list, columns.List) + assert isinstance(list_list.types[0], columns.List) + assert isinstance(map_list.types[0], columns.Text) + assert isinstance(map_list.types[1], columns.List) + assert isinstance(set_tuple.types[0], columns.Tuple) def test_default_empty_container_saving(self): """ @@ -902,9 +921,9 @@ def test_default_empty_container_saving(self): TestNestedModel.create(partition=pkey) m = TestNestedModel.get(partition=pkey) - self.assertEqual(m.list_list, list_list_master) - self.assertEqual(m.map_list, map_list_master) - self.assertEqual(m.set_tuple, set_tuple_master) + assert m.list_list == list_list_master + assert m.map_list == map_list_master + assert m.set_tuple == set_tuple_master def test_updates(self): """ @@ -931,9 +950,9 @@ def test_updates(self): m1.save() m2 = TestNestedModel.get(partition=m1.partition) - self.assertEqual(m2.list_list, list_list_replacement) - self.assertEqual(m2.map_list, map_list_replacement) - self.assertEqual(m2.set_tuple, set_tuple_replacement) + assert m2.list_list == list_list_replacement + assert m2.map_list == map_list_replacement + assert m2.set_tuple == set_tuple_replacement def test_update_from_non_empty_to_empty(self): """ @@ -955,9 +974,9 @@ def test_update_from_non_empty_to_empty(self): tmp.update() tmp = TestNestedModel.get(partition=tmp.partition) - self.assertEqual(tmp.list_list, []) - self.assertEqual(tmp.map_list, {}) - self.assertEqual(tmp.set_tuple, set()) + assert tmp.list_list == [] + assert tmp.map_list == {} + assert tmp.set_tuple == set() def test_insert_none(self): """ @@ 
-971,8 +990,6 @@ def test_insert_none(self): """ pkey = uuid4() tmp = TestNestedModel.create(partition=pkey, list_list=(None), map_list=(None), set_tuple=(None)) - self.assertEqual([], tmp.list_list) - self.assertEqual({}, tmp.map_list) - self.assertEqual(set(), tmp.set_tuple) - - + assert [] == tmp.list_list + assert {} == tmp.map_list + assert set() == tmp.set_tuple diff --git a/tests/integration/cqlengine/columns/test_counter_column.py b/tests/integration/cqlengine/columns/test_counter_column.py index 160b98d7c2..5f69475b34 100644 --- a/tests/integration/cqlengine/columns/test_counter_column.py +++ b/tests/integration/cqlengine/columns/test_counter_column.py @@ -13,6 +13,7 @@ # limitations under the License. from uuid import uuid4 +import pytest from cassandra.cqlengine import columns from cassandra.cqlengine.management import sync_table, drop_table @@ -32,37 +33,28 @@ class TestClassConstruction(BaseCassEngTestCase): def test_defining_a_non_counter_column_fails(self): """ Tests that defining a non counter column field in a model with a counter column fails """ - try: + with pytest.raises(ModelDefinitionException): class model(Model): partition = columns.UUID(primary_key=True, default=uuid4) counter = columns.Counter() text = columns.Text() - self.fail("did not raise expected ModelDefinitionException") - except ModelDefinitionException: - pass def test_defining_a_primary_key_counter_column_fails(self): """ Tests that defining primary keys on counter columns fails """ - try: + with pytest.raises(TypeError): class model(Model): partition = columns.UUID(primary_key=True, default=uuid4) cluster = columns.Counter(primary_ley=True) counter = columns.Counter() - self.fail("did not raise expected TypeError") - except TypeError: - pass # force it - try: + with pytest.raises(ModelDefinitionException): class model(Model): partition = columns.UUID(primary_key=True, default=uuid4) cluster = columns.Counter() cluster.primary_key = True counter = columns.Counter() - 
self.fail("did not raise expected ModelDefinitionException") - except ModelDefinitionException: - pass class TestCounterColumn(BaseCassEngTestCase): @@ -120,12 +112,12 @@ def test_save_after_no_update(self): # read back instance = TestCounterModel.get(partition=instance.partition) - self.assertEqual(instance.counter, expected_value) + assert instance.counter == expected_value # save after doing nothing instance.save() - self.assertEqual(instance.counter, expected_value) + assert instance.counter == expected_value # make sure there was no increment instance = TestCounterModel.get(partition=instance.partition) - self.assertEqual(instance.counter, expected_value) + assert instance.counter == expected_value diff --git a/tests/integration/cqlengine/columns/test_validation.py b/tests/integration/cqlengine/columns/test_validation.py index 32f20d52ff..ebffc0666c 100644 --- a/tests/integration/cqlengine/columns/test_validation.py +++ b/tests/integration/cqlengine/columns/test_validation.py @@ -33,6 +33,7 @@ from tests.integration import PROTOCOL_VERSION, CASSANDRA_VERSION, greaterthanorequalcass30, greaterthanorequalcass3_11 from tests.integration.cqlengine.base import BaseCassEngTestCase +import pytest class TestDatetime(BaseCassEngTestCase): @@ -53,7 +54,7 @@ def test_datetime_io(self): now = datetime.now() self.DatetimeTest.objects.create(test_id=0, created_at=now) dt2 = self.DatetimeTest.objects(test_id=0).first() - self.assertEqual(dt2.created_at.timetuple()[:6], now.timetuple()[:6]) + assert dt2.created_at.timetuple()[:6] == now.timetuple()[:6] def test_datetime_tzinfo_io(self): class TZ(tzinfo): @@ -65,45 +66,45 @@ def dst(self, date_time): now = datetime(1982, 1, 1, tzinfo=TZ()) dt = self.DatetimeTest.objects.create(test_id=1, created_at=now) dt2 = self.DatetimeTest.objects(test_id=1).first() - self.assertEqual(dt2.created_at.timetuple()[:6], (now + timedelta(hours=1)).timetuple()[:6]) + assert dt2.created_at.timetuple()[:6] == (now + 
timedelta(hours=1)).timetuple()[:6] @greaterthanorequalcass30 def test_datetime_date_support(self): today = date.today() self.DatetimeTest.objects.create(test_id=2, created_at=today) dt2 = self.DatetimeTest.objects(test_id=2).first() - self.assertEqual(dt2.created_at.isoformat(), datetime(today.year, today.month, today.day).isoformat()) + assert dt2.created_at.isoformat() == datetime(today.year, today.month, today.day).isoformat() result = self.DatetimeTest.objects.all().allow_filtering().filter(test_id=2).first() - self.assertEqual(result.created_at, datetime.combine(today, datetime.min.time())) + assert result.created_at == datetime.combine(today, datetime.min.time()) result = self.DatetimeTest.objects.all().allow_filtering().filter(test_id=2, created_at=today).first() - self.assertEqual(result.created_at, datetime.combine(today, datetime.min.time())) + assert result.created_at == datetime.combine(today, datetime.min.time()) def test_datetime_none(self): dt = self.DatetimeTest.objects.create(test_id=3, created_at=None) dt2 = self.DatetimeTest.objects(test_id=3).first() - self.assertIsNone(dt2.created_at) + assert dt2.created_at is None dts = self.DatetimeTest.objects.filter(test_id=3).values_list('created_at') - self.assertIsNone(dts[0][0]) + assert dts[0][0] is None def test_datetime_invalid(self): dt_value= 'INVALID' - with self.assertRaises(TypeError): + with pytest.raises(TypeError): self.DatetimeTest.objects.create(test_id=4, created_at=dt_value) def test_datetime_timestamp(self): dt_value = 1454520554 self.DatetimeTest.objects.create(test_id=5, created_at=dt_value) dt2 = self.DatetimeTest.objects(test_id=5).first() - self.assertEqual(dt2.created_at, datetime.fromtimestamp(dt_value, tz=timezone.utc).replace(tzinfo=None)) + assert dt2.created_at == datetime.fromtimestamp(dt_value, tz=timezone.utc).replace(tzinfo=None) def test_datetime_large(self): dt_value = datetime(2038, 12, 31, 10, 10, 10, 123000) self.DatetimeTest.objects.create(test_id=6, 
created_at=dt_value) dt2 = self.DatetimeTest.objects(test_id=6).first() - self.assertEqual(dt2.created_at, dt_value) + assert dt2.created_at == dt_value def test_datetime_truncate_microseconds(self): """ @@ -123,7 +124,7 @@ def test_datetime_truncate_microseconds(self): dt_truncated = datetime(2024, 12, 31, 10, 10, 10, 923000) self.DatetimeTest.objects.create(test_id=6, created_at=dt_value) dt2 = self.DatetimeTest.objects(test_id=6).first() - self.assertEqual(dt2.created_at,dt_truncated) + assert dt2.created_at == dt_truncated finally: # We need to always return behavior to default DateTime.truncate_microseconds = False @@ -141,9 +142,9 @@ def setUpClass(cls): def test_default_is_set(self): tmp = self.BoolDefaultValueTest.create(test_id=1) - self.assertEqual(True, tmp.stuff) + assert True == tmp.stuff tmp2 = self.BoolDefaultValueTest.get(test_id=1) - self.assertEqual(True, tmp2.stuff) + assert True == tmp2.stuff class TestBoolValidation(BaseCassEngTestCase): @@ -160,7 +161,7 @@ def test_validation_preserves_none(self): test_obj = self.BoolValidationTest(test_id=1) test_obj.validate() - self.assertIsNone(test_obj.bool_column) + assert test_obj.bool_column is None class TestVarInt(BaseCassEngTestCase): @@ -183,9 +184,9 @@ def test_varint_io(self): long_int = 92834902384092834092384028340283048239048203480234823048230482304820348239 int1 = self.VarIntTest.objects.create(test_id=0, bignum=long_int) int2 = self.VarIntTest.objects(test_id=0).first() - self.assertEqual(int1.bignum, int2.bignum) + assert int1.bignum == int2.bignum - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): self.VarIntTest.objects.create(test_id=0, bignum="not_a_number") @@ -222,10 +223,10 @@ def _check_value_is_correct_in_db(self, value): """ if value is None: result = self.model_class.objects.all().allow_filtering().filter(test_id=0).first() - self.assertIsNone(result.class_param) + assert result.class_param is None result = 
self.model_class.objects(test_id=0).first() - self.assertIsNone(result.class_param) + assert result.class_param is None else: if not isinstance(value, self.python_klass): @@ -234,16 +235,16 @@ def _check_value_is_correct_in_db(self, value): value_to_compare = value result = self.model_class.objects(test_id=0).first() - self.assertIsInstance(result.class_param, self.python_klass) - self.assertEqual(result.class_param, value_to_compare) + assert isinstance(result.class_param, self.python_klass) + assert result.class_param == value_to_compare result = self.model_class.objects.all().allow_filtering().filter(test_id=0).first() - self.assertIsInstance(result.class_param, self.python_klass) - self.assertEqual(result.class_param, value_to_compare) + assert isinstance(result.class_param, self.python_klass) + assert result.class_param == value_to_compare result = self.model_class.objects.all().allow_filtering().filter(test_id=0, class_param=value).first() - self.assertIsInstance(result.class_param, self.python_klass) - self.assertEqual(result.class_param, value_to_compare) + assert isinstance(result.class_param, self.python_klass) + assert result.class_param == value_to_compare return result @@ -276,10 +277,10 @@ def test_param_none(self): """ self.model_class.objects.create(test_id=1, class_param=None) dt2 = self.model_class.objects(test_id=1).first() - self.assertIsNone(dt2.class_param) + assert dt2.class_param is None dts = self.model_class.objects(test_id=1).values_list('class_param') - self.assertIsNone(dts[0][0]) + assert dts[0][0] is None class TestDate(DataType, BaseCassEngTestCase): @@ -541,22 +542,22 @@ def test_min_length(self): Ascii(min_length=5).validate('kevin') Ascii(min_length=5).validate('kevintastic') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Ascii(min_length=1).validate('') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Ascii(min_length=1).validate(None) - with 
self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Ascii(min_length=6).validate('') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Ascii(min_length=6).validate(None) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Ascii(min_length=6).validate('kevin') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Ascii(min_length=-1) def test_max_length(self): @@ -573,13 +574,13 @@ def test_max_length(self): Ascii(max_length=5).validate('b') Ascii(max_length=5).validate('blake') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Ascii(max_length=0).validate('b') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Ascii(max_length=5).validate('blaketastic') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Ascii(max_length=-1) def test_length_range(self): @@ -588,10 +589,10 @@ def test_length_range(self): Ascii(min_length=10, max_length=10) Ascii(min_length=10, max_length=11) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Ascii(min_length=10, max_length=9) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Ascii(min_length=1, max_length=0) def test_type_checking(self): @@ -599,26 +600,26 @@ def test_type_checking(self): Ascii().validate(u'unicode') Ascii().validate(bytearray('bytearray', encoding='ascii')) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Ascii().validate(5) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Ascii().validate(True) Ascii().validate("!#$%&\'()*+,-./") - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Ascii().validate('Beyonc' + chr(233)) if sys.version_info < (3, 1): - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Ascii().validate(u'Beyonc' + unichr(233)) def 
test_unaltering_validation(self): """ Test the validation step doesn't re-interpret values. """ - self.assertEqual(Ascii().validate(''), '') - self.assertEqual(Ascii().validate(None), None) - self.assertEqual(Ascii().validate('yo'), 'yo') + assert Ascii().validate('') == '' + assert Ascii().validate(None) is None + assert Ascii().validate('yo') == 'yo' def test_non_required_validation(self): """ Tests that validation is ok on none and blank values if required is False. """ @@ -629,26 +630,26 @@ def test_required_validation(self): """ Tests that validation raise on none and blank values if value required. """ Ascii(required=True).validate('k') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Ascii(required=True).validate('') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Ascii(required=True).validate(None) # With min_length set. Ascii(required=True, min_length=0).validate('k') Ascii(required=True, min_length=1).validate('k') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Ascii(required=True, min_length=2).validate('k') # With max_length set. 
Ascii(required=True, max_length=1).validate('k') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Ascii(required=True, max_length=2).validate('kevin') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Ascii(required=True, max_length=0) @@ -668,22 +669,22 @@ def test_min_length(self): Text(min_length=5).validate('blake') Text(min_length=5).validate('blaketastic') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Text(min_length=1).validate('') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Text(min_length=1).validate(None) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Text(min_length=6).validate('') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Text(min_length=6).validate(None) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Text(min_length=6).validate('blake') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Text(min_length=-1) def test_max_length(self): @@ -700,13 +701,13 @@ def test_max_length(self): Text(max_length=5).validate('b') Text(max_length=5).validate('blake') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Text(max_length=0).validate('b') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Text(max_length=5).validate('blaketastic') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Text(max_length=-1) def test_length_range(self): @@ -715,10 +716,10 @@ def test_length_range(self): Text(min_length=10, max_length=10) Text(min_length=10, max_length=11) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Text(min_length=10, max_length=9) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Text(min_length=1, max_length=0) def test_type_checking(self): @@ -726,10 +727,10 @@ def 
test_type_checking(self): Text().validate(u'unicode') Text().validate(bytearray('bytearray', encoding='ascii')) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Text().validate(5) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Text().validate(True) Text().validate("!#$%&\'()*+,-./") @@ -739,9 +740,9 @@ def test_type_checking(self): def test_unaltering_validation(self): """ Test the validation step doesn't re-interpret values. """ - self.assertEqual(Text().validate(''), '') - self.assertEqual(Text().validate(None), None) - self.assertEqual(Text().validate('yo'), 'yo') + assert Text().validate('') == '' + assert Text().validate(None) is None + assert Text().validate('yo') == 'yo' def test_non_required_validation(self): """ Tests that validation is ok on none and blank values if required is False """ @@ -752,26 +753,26 @@ def test_required_validation(self): """ Tests that validation raise on none and blank values if value required. """ Text(required=True).validate('b') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Text(required=True).validate('') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Text(required=True).validate(None) # With min_length set. Text(required=True, min_length=0).validate('b') Text(required=True, min_length=1).validate('b') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Text(required=True, min_length=2).validate('b') # With max_length set. 
Text(required=True, max_length=1).validate('b') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): Text(required=True, max_length=2).validate('blake') - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Text(required=True, max_length=0) @@ -781,7 +782,7 @@ class TestModel(Model): id = UUID(primary_key=True, default=uuid4) def test_extra_field(self): - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): self.TestModel.create(bacon=5000) @@ -834,5 +835,5 @@ def test_inet_saves(self): def test_non_address_fails(self): # TODO: presently this only tests that the server blows it up. Is there supposed to be local validation? - with self.assertRaises(InvalidRequest): + with pytest.raises(InvalidRequest): self.InetTestModel.create(address="what is going on here?") diff --git a/tests/integration/cqlengine/connections/test_connection.py b/tests/integration/cqlengine/connections/test_connection.py index c63836785e..78d5133e63 100644 --- a/tests/integration/cqlengine/connections/test_connection.py +++ b/tests/integration/cqlengine/connections/test_connection.py @@ -42,12 +42,12 @@ def tearDown(self): @local def test_connection_setup_with_setup(self): connection.setup(hosts=None, default_keyspace=None) - self.assertIsNotNone(connection.get_connection("default").cluster.metadata.get_host("127.0.0.1")) + assert connection.get_connection("default").cluster.metadata.get_host("127.0.0.1") is not None @local def test_connection_setup_with_default(self): connection.default() - self.assertIsNotNone(connection.get_connection("default").cluster.metadata.get_host("127.0.0.1")) + assert connection.get_connection("default").cluster.metadata.get_host("127.0.0.1") is not None def test_only_one_connection_is_created(self): """ @@ -63,7 +63,7 @@ def test_only_one_connection_is_created(self): number_of_clusters_before = len(_clusters_for_shutdown) connection.default() number_of_clusters_after = 
len(_clusters_for_shutdown) - self.assertEqual(number_of_clusters_after - number_of_clusters_before, 1) + assert number_of_clusters_after - number_of_clusters_before == 1 class SeveralConnectionsTest(BaseCassEngTestCase): @@ -119,11 +119,11 @@ def test_connection_session_switch(self): sync_table(TestConnectModel) TCM2 = TestConnectModel.create(id=1, keyspace=self.keyspace2) connection.set_session(self.session1) - self.assertEqual(1, TestConnectModel.objects.count()) - self.assertEqual(TestConnectModel.objects.first(), TCM1) + assert 1 == TestConnectModel.objects.count() + assert TestConnectModel.objects.first() == TCM1 connection.set_session(self.session2) - self.assertEqual(1, TestConnectModel.objects.count()) - self.assertEqual(TestConnectModel.objects.first(), TCM2) + assert 1 == TestConnectModel.objects.count() + assert TestConnectModel.objects.first() == TCM2 class ConnectionModel(Model): @@ -135,7 +135,7 @@ class ConnectionInitTest(unittest.TestCase): def test_default_connection_uses_legacy(self): connection.default() conn = connection.get_connection() - self.assertEqual(conn.cluster._config_mode, _ConfigMode.LEGACY) + assert conn.cluster._config_mode == _ConfigMode.LEGACY def test_connection_with_legacy_settings(self): connection.setup( @@ -144,7 +144,7 @@ def test_connection_with_legacy_settings(self): consistency=ConsistencyLevel.LOCAL_ONE ) conn = connection.get_connection() - self.assertEqual(conn.cluster._config_mode, _ConfigMode.LEGACY) + assert conn.cluster._config_mode == _ConfigMode.LEGACY def test_connection_from_session_with_execution_profile(self): cluster = TestCluster(execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)}) @@ -152,7 +152,7 @@ def test_connection_from_session_with_execution_profile(self): connection.default() connection.set_session(session) conn = connection.get_connection() - self.assertEqual(conn.cluster._config_mode, _ConfigMode.PROFILES) + assert conn.cluster._config_mode == 
_ConfigMode.PROFILES def test_connection_from_session_with_legacy_settings(self): cluster = TestCluster(load_balancing_policy=RoundRobinPolicy()) @@ -160,7 +160,7 @@ def test_connection_from_session_with_legacy_settings(self): session.row_factory = dict_factory connection.set_session(session) conn = connection.get_connection() - self.assertEqual(conn.cluster._config_mode, _ConfigMode.LEGACY) + assert conn.cluster._config_mode == _ConfigMode.LEGACY def test_uncommitted_session_uses_legacy(self): cluster = TestCluster() @@ -168,7 +168,7 @@ def test_uncommitted_session_uses_legacy(self): session.row_factory = dict_factory connection.set_session(session) conn = connection.get_connection() - self.assertEqual(conn.cluster._config_mode, _ConfigMode.LEGACY) + assert conn.cluster._config_mode == _ConfigMode.LEGACY def test_legacy_insert_query(self): connection.setup( @@ -176,21 +176,21 @@ def test_legacy_insert_query(self): default_keyspace=DEFAULT_KEYSPACE, consistency=ConsistencyLevel.LOCAL_ONE ) - self.assertEqual(connection.get_connection().cluster._config_mode, _ConfigMode.LEGACY) + assert connection.get_connection().cluster._config_mode == _ConfigMode.LEGACY sync_table(ConnectionModel) ConnectionModel.objects.create(key=0, some_data='text0') ConnectionModel.objects.create(key=1, some_data='text1') - self.assertEqual(ConnectionModel.objects(key=0)[0].some_data, 'text0') + assert ConnectionModel.objects(key=0)[0].some_data == 'text0' def test_execution_profile_insert_query(self): cluster = TestCluster(execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(row_factory=dict_factory)}) session = cluster.connect() connection.default() connection.set_session(session) - self.assertEqual(connection.get_connection().cluster._config_mode, _ConfigMode.PROFILES) + assert connection.get_connection().cluster._config_mode == _ConfigMode.PROFILES sync_table(ConnectionModel) ConnectionModel.objects.create(key=0, some_data='text0') ConnectionModel.objects.create(key=1, 
some_data='text1') - self.assertEqual(ConnectionModel.objects(key=0)[0].some_data, 'text0') + assert ConnectionModel.objects(key=0)[0].some_data == 'text0' diff --git a/tests/integration/cqlengine/management/test_compaction_settings.py b/tests/integration/cqlengine/management/test_compaction_settings.py index d58c419d0e..25484b30c9 100644 --- a/tests/integration/cqlengine/management/test_compaction_settings.py +++ b/tests/integration/cqlengine/management/test_compaction_settings.py @@ -20,6 +20,7 @@ from cassandra.cqlengine.models import Model from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.util import assertRegex class LeveledCompactionTestTable(Model): @@ -53,7 +54,7 @@ class LeveledCompactionChangesDetectionTest(Model): drop_table(LeveledCompactionChangesDetectionTest) sync_table(LeveledCompactionChangesDetectionTest) - self.assertFalse(_update_options(LeveledCompactionChangesDetectionTest)) + assert not _update_options(LeveledCompactionChangesDetectionTest) def test_compaction_not_altered_without_changes_sizetiered(self): class SizeTieredCompactionChangesDetectionTest(Model): @@ -71,7 +72,7 @@ class SizeTieredCompactionChangesDetectionTest(Model): drop_table(SizeTieredCompactionChangesDetectionTest) sync_table(SizeTieredCompactionChangesDetectionTest) - self.assertFalse(_update_options(SizeTieredCompactionChangesDetectionTest)) + assert not _update_options(SizeTieredCompactionChangesDetectionTest) def test_alter_actually_alters(self): tmp = copy.deepcopy(LeveledCompactionTestTable) @@ -82,7 +83,7 @@ def test_alter_actually_alters(self): table_meta = _get_table_metadata(tmp) - self.assertRegex(table_meta.export_as_string(), '.*SizeTieredCompactionStrategy.*') + assertRegex(table_meta.export_as_string(), '.*SizeTieredCompactionStrategy.*') def test_alter_options(self): @@ -96,11 +97,11 @@ class AlterTable(Model): drop_table(AlterTable) sync_table(AlterTable) table_meta = _get_table_metadata(AlterTable) - 
self.assertRegex(table_meta.export_as_string(), ".*'sstable_size_in_mb': '64'.*") + assertRegex(table_meta.export_as_string(), ".*'sstable_size_in_mb': '64'.*") AlterTable.__options__['compaction']['sstable_size_in_mb'] = '128' sync_table(AlterTable) table_meta = _get_table_metadata(AlterTable) - self.assertRegex(table_meta.export_as_string(), ".*'sstable_size_in_mb': '128'.*") + assertRegex(table_meta.export_as_string(), ".*'sstable_size_in_mb': '128'.*") class OptionsTest(BaseCassEngTestCase): @@ -110,7 +111,7 @@ def _verify_options(self, table_meta, expected_options): for name, value in expected_options.items(): if isinstance(value, str): - self.assertIn("%s = '%s'" % (name, value), cql) + assert "%s = '%s'" % (name, value) in cql else: start = cql.find("%s = {" % (name,)) end = cql.find('}', start) @@ -124,9 +125,9 @@ def _verify_options(self, table_meta, expected_options): attr = "'%s': '%s'" % (subname, subvalue.split('.')[-1]) found_at = cql.find(attr, start) else: - - self.assertTrue(found_at > start) - self.assertTrue(found_at < end) + + assert found_at > start + assert found_at < end def test_all_size_tiered_options(self): class AllSizeTieredOptionsModel(Model): diff --git a/tests/integration/cqlengine/management/test_management.py b/tests/integration/cqlengine/management/test_management.py index 55fb62f22c..1332680cef 100644 --- a/tests/integration/cqlengine/management/test_management.py +++ b/tests/integration/cqlengine/management/test_management.py @@ -29,6 +29,7 @@ from tests.integration.cqlengine.query.test_queryset import TestModel from cassandra.cqlengine.usertype import UserType from tests.integration.cqlengine import DEFAULT_KEYSPACE +import pytest INCLUDE_REPAIR = (not CASSANDRA_VERSION >= Version('4-a')) and SCYLLA_VERSION is None # This should cover DSE 6.0+ @@ -39,20 +40,20 @@ def test_create_drop_succeeeds(self): cluster = get_cluster() keyspace_ss = 'test_ks_ss' - self.assertNotIn(keyspace_ss, cluster.metadata.keyspaces) + assert 
keyspace_ss not in cluster.metadata.keyspaces management.create_keyspace_simple(keyspace_ss, 2) - self.assertIn(keyspace_ss, cluster.metadata.keyspaces) + assert keyspace_ss in cluster.metadata.keyspaces management.drop_keyspace(keyspace_ss) - self.assertNotIn(keyspace_ss, cluster.metadata.keyspaces) + assert keyspace_ss not in cluster.metadata.keyspaces keyspace_nts = 'test_ks_nts' - self.assertNotIn(keyspace_nts, cluster.metadata.keyspaces) + assert keyspace_nts not in cluster.metadata.keyspaces management.create_keyspace_network_topology(keyspace_nts, {'dc1': 1}) - self.assertIn(keyspace_nts, cluster.metadata.keyspaces) + assert keyspace_nts in cluster.metadata.keyspaces management.drop_keyspace(keyspace_nts) - self.assertNotIn(keyspace_nts, cluster.metadata.keyspaces) + assert keyspace_nts not in cluster.metadata.keyspaces class DropTableTest(BaseCassEngTestCase): @@ -177,30 +178,30 @@ def setUp(self): def test_add_column(self): sync_table(FirstModel) meta_columns = _get_table_metadata(FirstModel).columns - self.assertEqual(set(meta_columns), set(FirstModel._columns)) + assert set(meta_columns) == set(FirstModel._columns) sync_table(SecondModel) meta_columns = _get_table_metadata(FirstModel).columns - self.assertEqual(set(meta_columns), set(SecondModel._columns)) + assert set(meta_columns) == set(SecondModel._columns) sync_table(ThirdModel) meta_columns = _get_table_metadata(FirstModel).columns - self.assertEqual(len(meta_columns), 5) - self.assertEqual(len(ThirdModel._columns), 4) - self.assertIn('fourth_key', meta_columns) - self.assertNotIn('fourth_key', ThirdModel._columns) - self.assertIn('blah', ThirdModel._columns) - self.assertIn('blah', meta_columns) + assert len(meta_columns) == 5 + assert len(ThirdModel._columns) == 4 + assert 'fourth_key' in meta_columns + assert 'fourth_key' not in ThirdModel._columns + assert 'blah' in ThirdModel._columns + assert 'blah' in meta_columns sync_table(FourthModel) meta_columns = _get_table_metadata(FirstModel).columns 
- self.assertEqual(len(meta_columns), 5) - self.assertEqual(len(ThirdModel._columns), 4) - self.assertIn('fourth_key', meta_columns) - self.assertNotIn('fourth_key', FourthModel._columns) - self.assertIn('renamed', FourthModel._columns) - self.assertNotIn('renamed', meta_columns) - self.assertIn('blah', meta_columns) + assert len(meta_columns) == 5 + assert len(ThirdModel._columns) == 4 + assert 'fourth_key' in meta_columns + assert 'fourth_key' not in FourthModel._columns + assert 'renamed' in FourthModel._columns + assert 'renamed' not in meta_columns + assert 'blah' in meta_columns class ModelWithTableProperties(Model): @@ -239,8 +240,7 @@ def test_set_table_properties(self): expected.update({'read_repair_chance': 0.17985}) options = management._get_table_metadata(ModelWithTableProperties).options - self.assertEqual(dict([(k, options.get(k)) for k in expected.keys()]), - expected) + assert dict([(k, options.get(k)) for k in expected.keys()]) == expected def test_table_property_update(self): ModelWithTableProperties.__options__['bloom_filter_fp_chance'] = 0.66778 @@ -255,14 +255,15 @@ def test_table_property_update(self): table_options = management._get_table_metadata(ModelWithTableProperties).options - self.assertLessEqual(ModelWithTableProperties.__options__.items(), table_options.items()) + assert ModelWithTableProperties.__options__.items() <= table_options.items() def test_bogus_option_update(self): sync_table(ModelWithTableProperties) option = 'no way will this ever be an option' try: ModelWithTableProperties.__options__[option] = 'what was I thinking?' 
- self.assertRaisesRegex(KeyError, "Invalid table option.*%s.*" % option, sync_table, ModelWithTableProperties) + with pytest.raises(KeyError, match="Invalid table option.*%s.*" % option): + sync_table(ModelWithTableProperties) finally: ModelWithTableProperties.__options__.pop(option, None) @@ -278,14 +279,14 @@ def test_sync_table_works_with_primary_keys_only_tables(self): # blows up with DoesNotExist if table does not exist table_meta = management._get_table_metadata(PrimaryKeysOnlyModel) - self.assertIn('LeveledCompactionStrategy', table_meta.as_cql_query()) + assert 'LeveledCompactionStrategy' in table_meta.as_cql_query() PrimaryKeysOnlyModel.__options__['compaction']['class'] = 'SizeTieredCompactionStrategy' sync_table(PrimaryKeysOnlyModel) table_meta = management._get_table_metadata(PrimaryKeysOnlyModel) - self.assertIn('SizeTieredCompactionStrategy', table_meta.as_cql_query()) + assert 'SizeTieredCompactionStrategy' in table_meta.as_cql_query() def test_primary_key_validation(self): """ @@ -298,9 +299,12 @@ def test_primary_key_validation(self): @test_category object_mapper """ sync_table(PrimaryKeysOnlyModel) - self.assertRaises(CQLEngineException, sync_table, PrimaryKeysModelChanged) - self.assertRaises(CQLEngineException, sync_table, PrimaryKeysAddedClusteringKey) - self.assertRaises(CQLEngineException, sync_table, PrimaryKeysRemovedPk) + with pytest.raises(CQLEngineException): + sync_table(PrimaryKeysModelChanged) + with pytest.raises(CQLEngineException): + sync_table(PrimaryKeysAddedClusteringKey) + with pytest.raises(CQLEngineException): + sync_table(PrimaryKeysRemovedPk) class IndexModel(Model): @@ -364,12 +368,12 @@ def test_sync_warnings(self): with MockLoggingHandler().set_module_name(management.__name__) as mock_handler: sync_table(BaseInconsistent) sync_table(ChangedInconsistent) - self.assertTrue('differing from the model type' in mock_handler.messages.get('warning')[0]) + assert 'differing from the model type' in 
mock_handler.messages.get('warning')[0] if CASSANDRA_VERSION >= Version('2.1'): sync_type(DEFAULT_KEYSPACE, BaseInconsistentType) mock_handler.reset() sync_type(DEFAULT_KEYSPACE, ChangedInconsistentType) - self.assertTrue('differing from the model user type' in mock_handler.messages.get('warning')[0]) + assert 'differing from the model user type' in mock_handler.messages.get('warning')[0] class TestIndexSetModel(Model): @@ -401,12 +405,12 @@ def test_sync_index(self): """ sync_table(IndexModel) table_meta = management._get_table_metadata(IndexModel) - self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'second_key')) + assert management._get_index_name_by_column(table_meta, 'second_key') is not None # index already exists sync_table(IndexModel) table_meta = management._get_table_metadata(IndexModel) - self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'second_key')) + assert management._get_index_name_by_column(table_meta, 'second_key') is not None def test_sync_index_case_sensitive(self): """ @@ -421,12 +425,12 @@ def test_sync_index_case_sensitive(self): """ sync_table(IndexCaseSensitiveModel) table_meta = management._get_table_metadata(IndexCaseSensitiveModel) - self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'second_key')) + assert management._get_index_name_by_column(table_meta, 'second_key') is not None # index already exists sync_table(IndexCaseSensitiveModel) table_meta = management._get_table_metadata(IndexCaseSensitiveModel) - self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'second_key')) + assert management._get_index_name_by_column(table_meta, 'second_key') is not None @greaterthancass20 @requires_collection_indexes @@ -443,10 +447,10 @@ def test_sync_indexed_set(self): """ sync_table(TestIndexSetModel) table_meta = management._get_table_metadata(TestIndexSetModel) - self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'int_set')) - 
self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'int_list')) - self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'text_map')) - self.assertIsNotNone(management._get_index_name_by_column(table_meta, 'mixed_tuple')) + assert management._get_index_name_by_column(table_meta, 'int_set') is not None + assert management._get_index_name_by_column(table_meta, 'int_list') is not None + assert management._get_index_name_by_column(table_meta, 'text_map') is not None + assert management._get_index_name_by_column(table_meta, 'mixed_tuple') is not None class NonModelFailureTest(BaseCassEngTestCase): @@ -454,7 +458,7 @@ class FakeModel(object): pass def test_failure(self): - with self.assertRaises(CQLEngineException): + with pytest.raises(CQLEngineException): sync_table(self.FakeModel) @@ -475,9 +479,9 @@ class StaticModel(Model): with mock.patch.object(session, "execute", wraps=session.execute) as m: sync_table(StaticModel) - self.assertGreater(m.call_count, 0) + assert m.call_count > 0 statement = m.call_args[0][0].query_string - self.assertIn('"name" text static', statement) + assert '"name" text static' in statement # if we sync again, we should not apply an alter w/ a static sync_table(StaticModel) @@ -485,4 +489,4 @@ class StaticModel(Model): with mock.patch.object(session, "execute", wraps=session.execute) as m2: sync_table(StaticModel) - self.assertEqual(len(m2.call_args_list), 0) + assert len(m2.call_args_list) == 0 diff --git a/tests/integration/cqlengine/model/test_class_construction.py b/tests/integration/cqlengine/model/test_class_construction.py index dae97c4438..df0a57d543 100644 --- a/tests/integration/cqlengine/model/test_class_construction.py +++ b/tests/integration/cqlengine/model/test_class_construction.py @@ -20,6 +20,7 @@ from cassandra.cqlengine.query import ModelQuerySet, DMLQuery from tests.integration.cqlengine.base import BaseCassEngTestCase +import pytest class TestModelClassFunction(BaseCassEngTestCase): @@ 
-39,16 +40,16 @@ class TestModel(Model): text = columns.Text() # check class attibutes - self.assertHasAttr(TestModel, '_columns') - self.assertHasAttr(TestModel, 'id') - self.assertHasAttr(TestModel, 'text') + assert hasattr(TestModel, '_columns') + assert hasattr(TestModel, 'id') + assert hasattr(TestModel, 'text') # check instance attributes inst = TestModel() - self.assertHasAttr(inst, 'id') - self.assertHasAttr(inst, 'text') - self.assertIsNotNone(inst.id) - self.assertIsNone(inst.text) + assert hasattr(inst, 'id') + assert hasattr(inst, 'text') + assert inst.id is not None + assert inst.text is None def test_values_on_instantiation(self): """ @@ -61,15 +62,15 @@ class TestPerson(Model): # Check that defaults are available at instantiation. inst1 = TestPerson() - self.assertHasAttr(inst1, 'first_name') - self.assertHasAttr(inst1, 'last_name') - self.assertEqual(inst1.first_name, 'kevin') - self.assertEqual(inst1.last_name, 'deldycke') + assert hasattr(inst1, 'first_name') + assert hasattr(inst1, 'last_name') + assert inst1.first_name == 'kevin' + assert inst1.last_name == 'deldycke' # Check that values on instantiation overrides defaults. 
inst2 = TestPerson(first_name='bob', last_name='joe') - self.assertEqual(inst2.first_name, 'bob') - self.assertEqual(inst2.last_name, 'joe') + assert inst2.first_name == 'bob' + assert inst2.last_name == 'joe' def test_db_map(self): """ @@ -83,15 +84,15 @@ class WildDBNames(Model): numbers = columns.Integer(db_field='integers_etc') db_map = WildDBNames._db_map - self.assertEqual(db_map['words_and_whatnot'], 'content') - self.assertEqual(db_map['integers_etc'], 'numbers') + assert db_map['words_and_whatnot'] == 'content' + assert db_map['integers_etc'] == 'numbers' def test_attempting_to_make_duplicate_column_names_fails(self): """ Tests that trying to create conflicting db column names will fail """ - with self.assertRaisesRegex(ModelException, r".*more than once$"): + with pytest.raises(ModelException, match=r".*more than once$"): class BadNames(Model): words = columns.Text(primary_key=True) content = columns.Text(db_field='words') @@ -108,10 +109,10 @@ class Stuff(Model): content = columns.Text() numbers = columns.Integer() - self.assertEqual([x for x in Stuff._columns.keys()], ['id', 'words', 'content', 'numbers']) + assert [x for x in Stuff._columns.keys()] == ['id', 'words', 'content', 'numbers'] def test_exception_raised_when_creating_class_without_pk(self): - with self.assertRaises(ModelDefinitionException): + with pytest.raises(ModelDefinitionException): class TestModel(Model): count = columns.Integer() @@ -129,9 +130,9 @@ class Stuff(Model): inst1 = Stuff(num=5) inst2 = Stuff(num=7) - self.assertNotEqual(inst1.num, inst2.num) - self.assertEqual(inst1.num, 5) - self.assertEqual(inst2.num, 7) + assert inst1.num != inst2.num + assert inst1.num == 5 + assert inst2.num == 7 def test_superclass_fields_are_inherited(self): """ @@ -170,16 +171,16 @@ class ModelWithPartitionKeys(Model): cols = ModelWithPartitionKeys._columns - self.assertTrue(cols['c1'].primary_key) - self.assertFalse(cols['c1'].partition_key) + assert cols['c1'].primary_key + assert not 
cols['c1'].partition_key - self.assertTrue(cols['p1'].primary_key) - self.assertTrue(cols['p1'].partition_key) - self.assertTrue(cols['p2'].primary_key) - self.assertTrue(cols['p2'].partition_key) + assert cols['p1'].primary_key + assert cols['p1'].partition_key + assert cols['p2'].primary_key + assert cols['p2'].partition_key obj = ModelWithPartitionKeys(p1='a', p2='b') - self.assertEqual(obj.pk, ('a', 'b')) + assert obj.pk == ('a', 'b') def test_del_attribute_is_assigned_properly(self): """ Tests that columns that can be deleted have the del attribute """ @@ -191,7 +192,7 @@ class DelModel(Model): model = DelModel(key=4, data=5) del model.data - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): del model.key def test_does_not_exist_exceptions_are_not_shared_between_model(self): @@ -236,7 +237,7 @@ class NoKeyspace(Model): __abstract__ = True key = columns.UUID(primary_key=True) - self.assertEqual(len(warn), 0) + assert len(warn) == 0 class TestManualTableNaming(BaseCassEngTestCase): @@ -278,8 +279,8 @@ def test_proper_table_naming_case_insensitive(self): @test_category object_mapper """ - self.assertEqual(self.RenamedCaseInsensitiveTest.column_family_name(include_keyspace=False), 'manual_name') - self.assertEqual(self.RenamedCaseInsensitiveTest.column_family_name(include_keyspace=True), 'whatever.manual_name') + assert self.RenamedCaseInsensitiveTest.column_family_name(include_keyspace=False) == 'manual_name' + assert self.RenamedCaseInsensitiveTest.column_family_name(include_keyspace=True) == 'whatever.manual_name' def test_proper_table_naming_case_sensitive(self): """ @@ -292,8 +293,8 @@ def test_proper_table_naming_case_sensitive(self): @test_category object_mapper """ - self.assertEqual(self.RenamedCaseSensitiveTest.column_family_name(include_keyspace=False), '"Manual_Name"') - self.assertEqual(self.RenamedCaseSensitiveTest.column_family_name(include_keyspace=True), 'whatever."Manual_Name"') + assert 
self.RenamedCaseSensitiveTest.column_family_name(include_keyspace=False) == '"Manual_Name"' + assert self.RenamedCaseSensitiveTest.column_family_name(include_keyspace=True) == 'whatever."Manual_Name"' class AbstractModel(Model): @@ -339,18 +340,18 @@ def test_abstract_attribute_is_not_inherited(self): def test_attempting_to_save_abstract_model_fails(self): """ Attempting to save a model from an abstract model should fail """ - with self.assertRaises(CQLEngineException): + with pytest.raises(CQLEngineException): AbstractModelWithFullCols.create(pkey=1, data=2) def test_attempting_to_create_abstract_table_fails(self): """ Attempting to create a table from an abstract model should fail """ from cassandra.cqlengine.management import sync_table - with self.assertRaises(CQLEngineException): + with pytest.raises(CQLEngineException): sync_table(AbstractModelWithFullCols) def test_attempting_query_on_abstract_model_fails(self): """ Tests attempting to execute query with an abstract model fails """ - with self.assertRaises(CQLEngineException): + with pytest.raises(CQLEngineException): iter(AbstractModelWithFullCols.objects(pkey=5)).next() def test_abstract_columns_are_inherited(self): @@ -395,7 +396,7 @@ class CQModel(Model): part = columns.UUID(primary_key=True) data = columns.Text() - with self.assertRaises(self.TestException): + with pytest.raises(self.TestException): CQModel.create(part=uuid4(), data='s') def test_overriding_dmlqueryset(self): @@ -410,7 +411,7 @@ class CDQModel(Model): part = columns.UUID(primary_key=True) data = columns.Text() - with self.assertRaises(self.TestException): + with pytest.raises(self.TestException): CDQModel().save() @@ -422,4 +423,4 @@ def test_subclassing(self): class AlreadyLoadedTest(ConcreteModelWithCol): new_field = columns.Integer() - self.assertGreater(len(AlreadyLoadedTest()), length) + assert len(AlreadyLoadedTest()) > length diff --git a/tests/integration/cqlengine/model/test_model.py 
b/tests/integration/cqlengine/model/test_model.py index d5153843f5..cafe6ae9c9 100644 --- a/tests/integration/cqlengine/model/test_model.py +++ b/tests/integration/cqlengine/model/test_model.py @@ -22,6 +22,7 @@ from uuid import uuid1 from tests.integration import pypy from tests.integration.cqlengine.base import TestQueryUpdateModel +import pytest class TestModel(unittest.TestCase): """ Tests the non-io functionality of models """ @@ -35,8 +36,8 @@ class EqualityModel(Model): m0 = EqualityModel(pk=0) m1 = EqualityModel(pk=1) - self.assertEqual(m0, m0) - self.assertNotEqual(m0, m1) + assert m0 == m0 + assert m0 != m1 def test_model_equality(self): """ tests the model equality functionality """ @@ -51,8 +52,8 @@ class EqualityModel1(Model): m0 = EqualityModel0(pk=0) m1 = EqualityModel1(kk=1) - self.assertEqual(m0, m0) - self.assertNotEqual(m0, m1) + assert m0 == m0 + assert m0 != m1 def test_keywords_as_names(self): """ @@ -87,8 +88,8 @@ class table(Model): created = table.create(select=0, table='table') selected = table.objects(select=0)[0] - self.assertEqual(created.select, selected.select) - self.assertEqual(created.table, selected.table) + assert created.select == selected.select + assert created.table == selected.table # Alter should work class table(Model): @@ -101,9 +102,9 @@ class table(Model): created = table.create(select=1, table='table') selected = table.objects(select=1)[0] - self.assertEqual(created.select, selected.select) - self.assertEqual(created.table, selected.table) - self.assertEqual(created.where, selected.where) + assert created.select == selected.select + assert created.table == selected.table + assert created.where == selected.where drop_keyspace('keyspace') @@ -112,18 +113,19 @@ class TestModel(Model): k = columns.Integer(primary_key=True) # no model keyspace uses default - self.assertEqual(TestModel.column_family_name(), "%s.test_model" % (models.DEFAULT_KEYSPACE,)) + assert TestModel.column_family_name() == "%s.test_model" % 
(models.DEFAULT_KEYSPACE,) # model keyspace overrides TestModel.__keyspace__ = "my_test_keyspace" - self.assertEqual(TestModel.column_family_name(), "%s.test_model" % (TestModel.__keyspace__,)) + assert TestModel.column_family_name() == "%s.test_model" % (TestModel.__keyspace__,) # neither set should raise CQLEngineException before failing or formatting an invalid name del TestModel.__keyspace__ with patch('cassandra.cqlengine.models.DEFAULT_KEYSPACE', None): - self.assertRaises(CQLEngineException, TestModel.column_family_name) + with pytest.raises(CQLEngineException): + TestModel.column_family_name() # .. but we can still get the bare CF name - self.assertEqual(TestModel.column_family_name(include_keyspace=False), "test_model") + assert TestModel.column_family_name(include_keyspace=False) == "test_model" def test_column_family_case_sensitive(self): """ @@ -141,15 +143,16 @@ class TestModel(Model): k = columns.Integer(primary_key=True) - self.assertEqual(TestModel.column_family_name(), '%s."TestModel"' % (models.DEFAULT_KEYSPACE,)) + assert TestModel.column_family_name() == '%s."TestModel"' % (models.DEFAULT_KEYSPACE,) TestModel.__keyspace__ = "my_test_keyspace" - self.assertEqual(TestModel.column_family_name(), '%s."TestModel"' % (TestModel.__keyspace__,)) + assert TestModel.column_family_name() == '%s."TestModel"' % (TestModel.__keyspace__,) del TestModel.__keyspace__ with patch('cassandra.cqlengine.models.DEFAULT_KEYSPACE', None): - self.assertRaises(CQLEngineException, TestModel.column_family_name) - self.assertEqual(TestModel.column_family_name(include_keyspace=False), '"TestModel"') + with pytest.raises(CQLEngineException): + TestModel.column_family_name() + assert TestModel.column_family_name(include_keyspace=False) == '"TestModel"' class BuiltInAttributeConflictTest(unittest.TestCase): @@ -157,7 +160,7 @@ class BuiltInAttributeConflictTest(unittest.TestCase): def test_model_with_attribute_name_conflict(self): """should raise exception when model defines 
column that conflicts with built-in attribute""" - with self.assertRaises(ModelDefinitionException): + with pytest.raises(ModelDefinitionException): class IllegalTimestampColumnModel(Model): my_primary_key = columns.Integer(primary_key=True) @@ -165,7 +168,7 @@ class IllegalTimestampColumnModel(Model): def test_model_with_method_name_conflict(self): """should raise exception when model defines column that conflicts with built-in method""" - with self.assertRaises(ModelDefinitionException): + with pytest.raises(ModelDefinitionException): class IllegalFilterColumnModel(Model): my_primary_key = columns.Integer(primary_key=True) @@ -216,11 +219,11 @@ def test_comparison(self): TestQueryUpdateModel.text_list.column, TestQueryUpdateModel.text_map.column] - self.assertEqual(l, sorted(l)) - self.assertNotEqual(TestQueryUpdateModel.partition.column, TestQueryUpdateModel.cluster.column) - self.assertLessEqual(TestQueryUpdateModel.partition.column, TestQueryUpdateModel.cluster.column) - self.assertGreater(TestQueryUpdateModel.cluster.column, TestQueryUpdateModel.partition.column) - self.assertGreaterEqual(TestQueryUpdateModel.cluster.column, TestQueryUpdateModel.partition.column) + assert l == sorted(l) + assert TestQueryUpdateModel.partition.column != TestQueryUpdateModel.cluster.column + assert TestQueryUpdateModel.partition.column <= TestQueryUpdateModel.cluster.column + assert TestQueryUpdateModel.cluster.column > TestQueryUpdateModel.partition.column + assert TestQueryUpdateModel.cluster.column >= TestQueryUpdateModel.partition.column class TestDeprecationWarning(unittest.TestCase): @@ -259,9 +262,7 @@ class SensitiveModel(Model): # ignore DeprecationWarning('The loop argument is deprecated since Python 3.8, and scheduled for removal in Python 3.10.') relevant_warnings = [warn for warn in w if "The loop argument is deprecated" not in str(warn.message)] - self.assertIn("__table_name_case_sensitive__ will be removed in 4.0.", str(relevant_warnings[0].message)) - 
self.assertIn("__table_name_case_sensitive__ will be removed in 4.0.", str(relevant_warnings[1].message)) - self.assertIn("ModelQuerySet indexing with negative indices support will be removed in 4.0.", - str(relevant_warnings[2].message)) - self.assertIn("ModelQuerySet slicing with negative indices support will be removed in 4.0.", - str(relevant_warnings[3].message)) + assert "__table_name_case_sensitive__ will be removed in 4.0." in str(relevant_warnings[0].message) + assert "__table_name_case_sensitive__ will be removed in 4.0." in str(relevant_warnings[1].message) + assert "ModelQuerySet indexing with negative indices support will be removed in 4.0." in str(relevant_warnings[2].message) + assert "ModelQuerySet slicing with negative indices support will be removed in 4.0." in str(relevant_warnings[3].message) diff --git a/tests/integration/cqlengine/model/test_model_io.py b/tests/integration/cqlengine/model/test_model_io.py index 81240e90c5..f55815310a 100644 --- a/tests/integration/cqlengine/model/test_model_io.py +++ b/tests/integration/cqlengine/model/test_model_io.py @@ -33,6 +33,7 @@ from tests.integration import PROTOCOL_VERSION, greaterthanorequalcass3_10 from tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration.cqlengine import DEFAULT_KEYSPACE +from tests.util import assertSetEqual class TestModel(Model): @@ -73,13 +74,13 @@ def test_model_save_and_load(self): Tests that models can be saved and retrieved, using the create method. 
""" tm = TestModel.create(count=8, text='123456789') - self.assertIsInstance(tm, TestModel) + assert isinstance(tm, TestModel) tm2 = TestModel.objects(id=tm.pk).first() - self.assertIsInstance(tm2, TestModel) + assert isinstance(tm2, TestModel) for cname in tm._columns.keys(): - self.assertEqual(getattr(tm, cname), getattr(tm2, cname)) + assert getattr(tm, cname) == getattr(tm2, cname) def test_model_instantiation_save_and_load(self): """ @@ -88,14 +89,14 @@ def test_model_instantiation_save_and_load(self): """ tm = TestModel(count=8, text='123456789') # Tests that values are available on instantiation. - self.assertIsNotNone(tm['id']) - self.assertEqual(tm.count, 8) - self.assertEqual(tm.text, '123456789') + assert tm['id'] is not None + assert tm.count == 8 + assert tm.text == '123456789' tm.save() tm2 = TestModel.objects(id=tm.id).first() for cname in tm._columns.keys(): - self.assertEqual(getattr(tm, cname), getattr(tm2, cname)) + assert getattr(tm, cname) == getattr(tm2, cname) def test_model_read_as_dict(self): """ @@ -108,18 +109,16 @@ def test_model_read_as_dict(self): 'text': tm.text, 'a_bool': tm.a_bool, } - self.assertEqual(sorted(tm.keys()), sorted(column_dict.keys())) + assert sorted(tm.keys()) == sorted(column_dict.keys()) - self.assertSetEqual(set(tm.values()), set(column_dict.values())) - self.assertEqual( - sorted(tm.items(), key=itemgetter(0)), - sorted(column_dict.items(), key=itemgetter(0))) - self.assertEqual(len(tm), len(column_dict)) + assertSetEqual(set(tm.values()), set(column_dict.values())) + assert sorted(tm.items(), key=itemgetter(0)) == sorted(column_dict.items(), key=itemgetter(0)) + assert len(tm) == len(column_dict) for column_id in column_dict.keys(): - self.assertEqual(tm[column_id], column_dict[column_id]) + assert tm[column_id] == column_dict[column_id] tm['count'] = 6 - self.assertEqual(tm.count, 6) + assert tm.count == 6 def test_model_updating_works_properly(self): """ @@ -132,8 +131,8 @@ def 
test_model_updating_works_properly(self): tm.save() tm2 = TestModel.objects(id=tm.pk).first() - self.assertEqual(tm.count, tm2.count) - self.assertEqual(tm.a_bool, tm2.a_bool) + assert tm.count == tm2.count + assert tm.a_bool == tm2.a_bool def test_model_deleting_works_properly(self): """ @@ -142,7 +141,7 @@ def test_model_deleting_works_properly(self): tm = TestModel.create(count=8, text='123456789') tm.delete() tm2 = TestModel.objects(id=tm.pk).first() - self.assertIsNone(tm2) + assert tm2 is None def test_column_deleting_works_properly(self): """ @@ -152,10 +151,10 @@ def test_column_deleting_works_properly(self): tm.save() tm2 = TestModel.objects(id=tm.pk).first() - self.assertIsInstance(tm2, TestModel) + assert isinstance(tm2, TestModel) - self.assertTrue(tm2.text is None) - self.assertTrue(tm2._values['text'].previous_value is None) + assert tm2.text is None + assert tm2._values['text'].previous_value is None def test_a_sensical_error_is_raised_if_you_try_to_create_a_table_twice(self): """ @@ -212,26 +211,26 @@ class AllDatatypesModel(Model): m=UUID('067e6162-3b6f-4ae2-a171-2470b63dff00'), n=int(str(2147483647) + '000'), o=Duration(2, 3, 4)) - self.assertEqual(1, AllDatatypesModel.objects.count()) + assert 1 == AllDatatypesModel.objects.count() output = AllDatatypesModel.objects.first() for i, i_char in enumerate(range(ord('a'), ord('a') + 14)): - self.assertEqual(input[i], output[chr(i_char)]) + assert input[i] == output[chr(i_char)] def test_can_specify_none_instead_of_default(self): - self.assertIsNotNone(TestModel.a_bool.column.default) + assert TestModel.a_bool.column.default is not None # override default inst = TestModel.create(a_bool=None) - self.assertIsNone(inst.a_bool) + assert inst.a_bool is None queried = TestModel.objects(id=inst.id).first() - self.assertIsNone(queried.a_bool) + assert queried.a_bool is None # letting default be set inst = TestModel.create() - self.assertEqual(inst.a_bool, TestModel.a_bool.column.default) + assert inst.a_bool == 
TestModel.a_bool.column.default queried = TestModel.objects(id=inst.id).first() - self.assertEqual(queried.a_bool, TestModel.a_bool.column.default) + assert queried.a_bool == TestModel.a_bool.column.default def test_can_insert_model_with_all_protocol_v4_column_types(self): """ @@ -265,11 +264,11 @@ class v4DatatypesModel(Model): v4DatatypesModel.create(id=0, a=date(1970, 1, 1), b=32523, c=time(16, 47, 25, 7), d=123) - self.assertEqual(1, v4DatatypesModel.objects.count()) + assert 1 == v4DatatypesModel.objects.count() output = v4DatatypesModel.objects.first() for i, i_char in enumerate(range(ord('a'), ord('a') + 3)): - self.assertEqual(input[i], output[chr(i_char)]) + assert input[i] == output[chr(i_char)] def test_can_insert_double_and_float(self): """ @@ -292,16 +291,16 @@ class FloatingPointModel(Model): FloatingPointModel.create(id=0, f=2.39) output = FloatingPointModel.objects.first() - self.assertEqual(2.390000104904175, output.f) # float loses precision + assert 2.390000104904175 == output.f # float loses precision FloatingPointModel.create(id=0, f=3.4028234663852886e+38, d=2.39) output = FloatingPointModel.objects.first() - self.assertEqual(3.4028234663852886e+38, output.f) - self.assertEqual(2.39, output.d) # double retains precision + assert 3.4028234663852886e+38 == output.f + assert 2.39 == output.d # double retains precision FloatingPointModel.create(id=0, d=3.4028234663852886e+38) output = FloatingPointModel.objects.first() - self.assertEqual(3.4028234663852886e+38, output.d) + assert 3.4028234663852886e+38 == output.d class TestMultiKeyModel(Model): @@ -331,11 +330,11 @@ def test_deleting_only_deletes_one_object(self): for i in range(5): TestMultiKeyModel.create(partition=partition, cluster=i, count=i, text=str(i)) - self.assertTrue(TestMultiKeyModel.filter(partition=partition).count() == 5) + assert TestMultiKeyModel.filter(partition=partition).count() == 5 TestMultiKeyModel.get(partition=partition, cluster=0).delete() - 
self.assertTrue(TestMultiKeyModel.filter(partition=partition).count() == 4) + assert TestMultiKeyModel.filter(partition=partition).count() == 4 TestMultiKeyModel.filter(partition=partition).delete() @@ -370,8 +369,8 @@ def test_vanilla_update(self): self.instance.save() check = TestMultiKeyModel.get(partition=self.instance.partition, cluster=self.instance.cluster) - self.assertTrue(check.count == 5) - self.assertTrue(check.text == 'happy') + assert check.count == 5 + assert check.text == 'happy' def test_deleting_only(self): self.instance.count = None @@ -379,79 +378,79 @@ def test_deleting_only(self): self.instance.save() check = TestMultiKeyModel.get(partition=self.instance.partition, cluster=self.instance.cluster) - self.assertTrue(check.count is None) - self.assertTrue(check.text is None) + assert check.count is None + assert check.text is None def test_get_changed_columns(self): - self.assertTrue(self.instance.get_changed_columns() == []) + assert self.instance.get_changed_columns() == [] self.instance.count = 1 changes = self.instance.get_changed_columns() - self.assertTrue(len(changes) == 1) - self.assertTrue(changes == ['count']) + assert len(changes) == 1 + assert changes == ['count'] self.instance.save() - self.assertTrue(self.instance.get_changed_columns() == []) + assert self.instance.get_changed_columns() == [] def test_previous_value_tracking_of_persisted_instance(self): # Check initial internal states. - self.assertTrue(self.instance.get_changed_columns() == []) - self.assertTrue(self.instance._values['count'].previous_value == 0) + assert self.instance.get_changed_columns() == [] + assert self.instance._values['count'].previous_value == 0 # Change value and check internal states. 
self.instance.count = 1 - self.assertTrue(self.instance.get_changed_columns() == ['count']) - self.assertTrue(self.instance._values['count'].previous_value == 0) + assert self.instance.get_changed_columns() == ['count'] + assert self.instance._values['count'].previous_value == 0 # Internal states should be updated on save. self.instance.save() - self.assertTrue(self.instance.get_changed_columns() == []) - self.assertTrue(self.instance._values['count'].previous_value == 1) + assert self.instance.get_changed_columns() == [] + assert self.instance._values['count'].previous_value == 1 # Change value twice. self.instance.count = 2 - self.assertTrue(self.instance.get_changed_columns() == ['count']) - self.assertTrue(self.instance._values['count'].previous_value == 1) + assert self.instance.get_changed_columns() == ['count'] + assert self.instance._values['count'].previous_value == 1 self.instance.count = 3 - self.assertTrue(self.instance.get_changed_columns() == ['count']) - self.assertTrue(self.instance._values['count'].previous_value == 1) + assert self.instance.get_changed_columns() == ['count'] + assert self.instance._values['count'].previous_value == 1 # Internal states updated on save. self.instance.save() - self.assertTrue(self.instance.get_changed_columns() == []) - self.assertTrue(self.instance._values['count'].previous_value == 3) + assert self.instance.get_changed_columns() == [] + assert self.instance._values['count'].previous_value == 3 # Change value and reset it. 
self.instance.count = 2 - self.assertTrue(self.instance.get_changed_columns() == ['count']) - self.assertTrue(self.instance._values['count'].previous_value == 3) + assert self.instance.get_changed_columns() == ['count'] + assert self.instance._values['count'].previous_value == 3 self.instance.count = 3 - self.assertTrue(self.instance.get_changed_columns() == []) - self.assertTrue(self.instance._values['count'].previous_value == 3) + assert self.instance.get_changed_columns() == [] + assert self.instance._values['count'].previous_value == 3 # Nothing to save: values in initial conditions. self.instance.save() - self.assertTrue(self.instance.get_changed_columns() == []) - self.assertTrue(self.instance._values['count'].previous_value == 3) + assert self.instance.get_changed_columns() == [] + assert self.instance._values['count'].previous_value == 3 # Change Multiple values self.instance.count = 4 self.instance.text = "changed" - self.assertTrue(len(self.instance.get_changed_columns()) == 2) - self.assertTrue('text' in self.instance.get_changed_columns()) - self.assertTrue('count' in self.instance.get_changed_columns()) + assert len(self.instance.get_changed_columns()) == 2 + assert 'text' in self.instance.get_changed_columns() + assert 'count' in self.instance.get_changed_columns() self.instance.save() - self.assertTrue(self.instance.get_changed_columns() == []) + assert self.instance.get_changed_columns() == [] # Reset Multiple Values self.instance.count = 5 self.instance.text = "changed" - self.assertTrue(self.instance.get_changed_columns() == ['count']) + assert self.instance.get_changed_columns() == ['count'] self.instance.text = "changed2" - self.assertTrue(len(self.instance.get_changed_columns()) == 2) - self.assertTrue('text' in self.instance.get_changed_columns()) - self.assertTrue('count' in self.instance.get_changed_columns()) + assert len(self.instance.get_changed_columns()) == 2 + assert 'text' in self.instance.get_changed_columns() + assert 'count' in 
self.instance.get_changed_columns() self.instance.count = 4 self.instance.text = "changed" - self.assertTrue(self.instance.get_changed_columns() == []) + assert self.instance.get_changed_columns() == [] def test_previous_value_tracking_on_instantiation(self): self.instance = TestMultiKeyModel( @@ -461,30 +460,30 @@ def test_previous_value_tracking_on_instantiation(self): text='happy') # Columns of instances not persisted yet should be marked as changed. - self.assertTrue(set(self.instance.get_changed_columns()) == set([ - 'partition', 'cluster', 'count', 'text'])) - self.assertTrue(self.instance._values['partition'].previous_value is None) - self.assertTrue(self.instance._values['cluster'].previous_value is None) - self.assertTrue(self.instance._values['count'].previous_value is None) - self.assertTrue(self.instance._values['text'].previous_value is None) + assert set(self.instance.get_changed_columns()) == set([ + 'partition', 'cluster', 'count', 'text']) + assert self.instance._values['partition'].previous_value is None + assert self.instance._values['cluster'].previous_value is None + assert self.instance._values['count'].previous_value is None + assert self.instance._values['text'].previous_value is None # Value changes doesn't affect internal states. self.instance.count = 1 - self.assertTrue('count' in self.instance.get_changed_columns()) - self.assertTrue(self.instance._values['count'].previous_value is None) + assert 'count' in self.instance.get_changed_columns() + assert self.instance._values['count'].previous_value is None self.instance.count = 2 - self.assertTrue('count' in self.instance.get_changed_columns()) - self.assertTrue(self.instance._values['count'].previous_value is None) + assert 'count' in self.instance.get_changed_columns() + assert self.instance._values['count'].previous_value is None # Value reset is properly tracked. 
self.instance.count = None - self.assertTrue('count' not in self.instance.get_changed_columns()) - self.assertTrue(self.instance._values['count'].previous_value is None) + assert 'count' not in self.instance.get_changed_columns() + assert self.instance._values['count'].previous_value is None self.instance.save() - self.assertTrue(self.instance.get_changed_columns() == []) - self.assertTrue(self.instance._values['count'].previous_value is None) - self.assertTrue(self.instance.count is None) + assert self.instance.get_changed_columns() == [] + assert self.instance._values['count'].previous_value is None + assert self.instance.count is None def test_previous_value_tracking_on_instantiation_with_default(self): @@ -503,31 +502,31 @@ class TestDefaultValueTracking(Model): int3=7777, int5=5555) - self.assertEqual(instance.id, 1) - self.assertEqual(instance.int1, 9999) - self.assertEqual(instance.int2, 456) - self.assertEqual(instance.int3, 7777) - self.assertIsNotNone(instance.int4) - self.assertIsInstance(instance.int4, int) - self.assertGreaterEqual(instance.int4, 0) - self.assertLessEqual(instance.int4, 1000) - self.assertEqual(instance.int5, 5555) - self.assertTrue(instance.int6 is None) + assert instance.id == 1 + assert instance.int1 == 9999 + assert instance.int2 == 456 + assert instance.int3 == 7777 + assert instance.int4 is not None + assert isinstance(instance.int4, int) + assert instance.int4 >= 0 + assert instance.int4 <= 1000 + assert instance.int5 == 5555 + assert instance.int6 is None # All previous values are unset as the object hasn't been persisted # yet. 
- self.assertTrue(instance._values['id'].previous_value is None) - self.assertTrue(instance._values['int1'].previous_value is None) - self.assertTrue(instance._values['int2'].previous_value is None) - self.assertTrue(instance._values['int3'].previous_value is None) - self.assertTrue(instance._values['int4'].previous_value is None) - self.assertTrue(instance._values['int5'].previous_value is None) - self.assertTrue(instance._values['int6'].previous_value is None) + assert instance._values['id'].previous_value is None + assert instance._values['int1'].previous_value is None + assert instance._values['int2'].previous_value is None + assert instance._values['int3'].previous_value is None + assert instance._values['int4'].previous_value is None + assert instance._values['int5'].previous_value is None + assert instance._values['int6'].previous_value is None # All explicitely set columns, and those with default values are # flagged has changed. - self.assertTrue(set(instance.get_changed_columns()) == set([ - 'id', 'int1', 'int3', 'int5'])) + assert set(instance.get_changed_columns()) == set([ + 'id', 'int1', 'int3', 'int5']) def test_save_to_none(self): """ @@ -554,20 +553,20 @@ def test_save_to_none(self): text_set=text_set, text_map=text_map) initial.save() current = TestModelSave.objects.get(partition=partition, cluster=cluster) - self.assertEqual(current.text, text) - self.assertEqual(current.text_list, text_list) - self.assertEqual(current.text_set, text_set) - self.assertEqual(current.text_map, text_map) + assert current.text == text + assert current.text_list == text_list + assert current.text_set == text_set + assert current.text_map == text_map next = TestModelSave(partition=partition, cluster=cluster, text=None, text_list=None, text_set=None, text_map=None) next.save() current = TestModelSave.objects.get(partition=partition, cluster=cluster) - self.assertEqual(current.text, None) - self.assertEqual(current.text_list, []) - self.assertEqual(current.text_set, 
set()) - self.assertEqual(current.text_map, {}) + assert current.text == None + assert current.text_list == [] + assert current.text_set == set() + assert current.text_map == {} def test_none_filter_fails(): @@ -602,28 +601,28 @@ def test_success_case(self): # object hasn't been saved, # shouldn't be able to update - self.assertTrue(not tm._is_persisted) - self.assertTrue(not tm._can_update()) + assert not tm._is_persisted + assert not tm._can_update() tm.save() # object has been saved, # should be able to update - self.assertTrue(tm._is_persisted) - self.assertTrue(tm._can_update()) + assert tm._is_persisted + assert tm._can_update() tm.count = 200 # primary keys haven't changed, # should still be able to update - self.assertTrue(tm._can_update()) + assert tm._can_update() tm.save() tm.id = uuid4() # primary keys have changed, # should not be able to update - self.assertTrue(not tm._can_update()) + assert not tm._can_update() class IndexDefinitionModel(Model): @@ -656,9 +655,9 @@ def test_reserved_cql_words_can_be_used_as_column_names(self): model2 = ReservedWordModel.filter(token='1') - self.assertTrue(len(model2) == 1) - self.assertTrue(model1.token == model2[0].token) - self.assertTrue(model1.insert == model2[0].insert) + assert len(model2) == 1 + assert model1.token == model2[0].token + assert model1.insert == model2[0].insert class TestQueryModel(Model): @@ -697,14 +696,14 @@ def test_query_with_date(self): day = date(2013, 11, 26) obj = TestQueryModel.create(test_id=uid, date=day, description=u'foo') - self.assertEqual(obj.description, u'foo') + assert obj.description == u'foo' inst = TestQueryModel.filter( TestQueryModel.test_id == uid, TestQueryModel.date == day).limit(1).first() - self.assertTrue(inst.test_id == uid) - self.assertTrue(inst.date == day) + assert inst.test_id == uid + assert inst.date == day class BasicModelNoRouting(Model): @@ -774,16 +773,16 @@ def test_routing_key_is_ignored(self): mrk = BasicModelNoRouting._routing_key_from_values([1], 
self.session.cluster.protocol_version) simple = SimpleStatement("") simple.routing_key = mrk - self.assertNotEqual(bound.routing_key, simple.routing_key) + assert bound.routing_key != simple.routing_key # Verify that basic create, update and delete work with no routing key t = BasicModelNoRouting.create(k=2, v=3) t.update(v=4).save() f = BasicModelNoRouting.objects.filter(k=2).first() - self.assertEqual(t, f) + assert t == f t.delete() - self.assertEqual(BasicModelNoRouting.objects.count(), 0) + assert BasicModelNoRouting.objects.count() == 0 def test_routing_key_generation_basic(self): @@ -806,7 +805,7 @@ def test_routing_key_generation_basic(self): mrk = BasicModel._routing_key_from_values([1], self.session.cluster.protocol_version) simple = SimpleStatement("") simple.routing_key = mrk - self.assertEqual(bound.routing_key, simple.routing_key) + assert bound.routing_key == simple.routing_key def test_routing_key_generation_multi(self): """ @@ -827,7 +826,7 @@ def test_routing_key_generation_multi(self): mrk = BasicModelMulti._routing_key_from_values([1, 2], self.session.cluster.protocol_version) simple = SimpleStatement("") simple.routing_key = mrk - self.assertEqual(bound.routing_key, simple.routing_key) + assert bound.routing_key == simple.routing_key def test_routing_key_generation_complex(self): """ @@ -853,7 +852,7 @@ def test_routing_key_generation_complex(self): mrk = ComplexModelRouting._routing_key_from_values([partition, cluster, text, float], self.session.cluster.protocol_version) simple = SimpleStatement("") simple.routing_key = mrk - self.assertEqual(bound.routing_key, simple.routing_key) + assert bound.routing_key == simple.routing_key def test_partition_key_index(self): """ @@ -899,7 +898,7 @@ def _check_partition_value_generation(self, model, state, reverse=False): # Those specified in the models partition field for indx, value in enumerate(state.partition_key_values(model._partition_key_index)): name = res.get(value) - self.assertEqual(indx, 
model._partition_key_index.get(name)) + assert indx == model._partition_key_index.get(name) def test_none_filter_fails(): diff --git a/tests/integration/cqlengine/model/test_polymorphism.py b/tests/integration/cqlengine/model/test_polymorphism.py index f27703367d..a37b499df6 100644 --- a/tests/integration/cqlengine/model/test_polymorphism.py +++ b/tests/integration/cqlengine/model/test_polymorphism.py @@ -20,20 +20,21 @@ from cassandra.cqlengine.connection import get_session from tests.integration.cqlengine.base import BaseCassEngTestCase from cassandra.cqlengine import management +import pytest class TestInheritanceClassConstruction(BaseCassEngTestCase): def test_multiple_discriminator_value_failure(self): """ Tests that defining a model with more than one discriminator column fails """ - with self.assertRaises(models.ModelDefinitionException): + with pytest.raises(models.ModelDefinitionException): class M(models.Model): partition = columns.Integer(primary_key=True) type1 = columns.Integer(discriminator_column=True) type2 = columns.Integer(discriminator_column=True) def test_no_discriminator_column_failure(self): - with self.assertRaises(models.ModelDefinitionException): + with pytest.raises(models.ModelDefinitionException): class M(models.Model): __discriminator_value__ = 1 @@ -86,7 +87,7 @@ class M1(Base): assert Base.column_family_name() == M1.column_family_name() def test_collection_columns_cant_be_discriminator_column(self): - with self.assertRaises(models.ModelDefinitionException): + with pytest.raises(models.ModelDefinitionException): class Base(models.Model): partition = columns.Integer(primary_key=True) @@ -124,7 +125,7 @@ def tearDownClass(cls): management.drop_table(Inherit2) def test_saving_base_model_fails(self): - with self.assertRaises(models.PolymorphicModelException): + with pytest.raises(models.PolymorphicModelException): InheritBase.create() def test_saving_subclass_saves_disc_value(self): @@ -154,7 +155,7 @@ def 
test_delete_on_subclass_does_not_include_disc_value(self): # not sure how we would even get here if it was in there # since the CQL would fail. - self.assertNotIn("row_type", m.call_args[0][0].query_string) + assert "row_type" not in m.call_args[0][0].query_string class UnindexedInheritBase(models.Model): @@ -210,9 +211,9 @@ def test_subclassed_model_results_work_properly(self): assert len(list(UnindexedInherit2.objects(partition=p1.partition, cluster__in=[p2.cluster, p3.cluster]))) == 2 def test_conflicting_type_results(self): - with self.assertRaises(models.PolymorphicModelException): + with pytest.raises(models.PolymorphicModelException): list(UnindexedInherit1.objects(partition=self.p1.partition)) - with self.assertRaises(models.PolymorphicModelException): + with pytest.raises(models.PolymorphicModelException): list(UnindexedInherit2.objects(partition=self.p1.partition)) @@ -251,5 +252,5 @@ def tearDownClass(cls): management.drop_table(IndexedInherit2) def test_success_case(self): - self.assertEqual(len(list(IndexedInherit1.objects(partition=self.p1.partition))), 1) - self.assertEqual(len(list(IndexedInherit2.objects(partition=self.p1.partition))), 1) + assert len(list(IndexedInherit1.objects(partition=self.p1.partition))) == 1 + assert len(list(IndexedInherit2.objects(partition=self.p1.partition))) == 1 diff --git a/tests/integration/cqlengine/model/test_udts.py b/tests/integration/cqlengine/model/test_udts.py index 7063df8caa..80f1b9693f 100644 --- a/tests/integration/cqlengine/model/test_udts.py +++ b/tests/integration/cqlengine/model/test_udts.py @@ -28,6 +28,7 @@ from tests.integration import PROTOCOL_VERSION from tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration.cqlengine import DEFAULT_KEYSPACE +import pytest class User(UserType): @@ -75,8 +76,8 @@ class User(UserType): sync_type(DEFAULT_KEYSPACE, User) user = User(age=42, name="John") - self.assertEqual(42, user.age) - self.assertEqual("John", user.name) + assert 42 == 
user.age + assert "John" == user.name # Add a field class User(UserType): @@ -88,9 +89,9 @@ class User(UserType): user = User(age=42) user["name"] = "John" user["gender"] = "male" - self.assertEqual(42, user.age) - self.assertEqual("John", user.name) - self.assertEqual("male", user.gender) + assert 42 == user.age + assert "John" == user.name + assert "male" == user.gender # Remove a field class User(UserType): @@ -99,7 +100,7 @@ class User(UserType): sync_type(DEFAULT_KEYSPACE, User) user = User(age=42, name="John", gender="male") - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): user.gender def test_can_insert_udts(self): @@ -110,13 +111,13 @@ def test_can_insert_udts(self): user = User(age=42, name="John") UserModel.create(id=0, info=user) - self.assertEqual(1, UserModel.objects.count()) + assert 1 == UserModel.objects.count() john = UserModel.objects.first() - self.assertEqual(0, john.id) - self.assertTrue(type(john.info) is User) - self.assertEqual(42, john.info.age) - self.assertEqual("John", john.info.name) + assert 0 == john.id + assert type(john.info) is User + assert 42 == john.info.age + assert "John" == john.info.name def test_can_update_udts(self): sync_table(UserModel) @@ -126,15 +127,15 @@ def test_can_update_udts(self): created_user = UserModel.create(id=0, info=user) john_info = UserModel.objects.first().info - self.assertEqual(42, john_info.age) - self.assertEqual("John", john_info.name) + assert 42 == john_info.age + assert "John" == john_info.name created_user.info = User(age=22, name="Mary") created_user.update() mary_info = UserModel.objects.first().info - self.assertEqual(22, mary_info["age"]) - self.assertEqual("Mary", mary_info["name"]) + assert 22 == mary_info["age"] + assert "Mary" == mary_info["name"] def test_can_update_udts_with_nones(self): sync_table(UserModel) @@ -144,14 +145,14 @@ def test_can_update_udts_with_nones(self): created_user = UserModel.create(id=0, info=user) john_info = 
UserModel.objects.first().info - self.assertEqual(42, john_info.age) - self.assertEqual("John", john_info.name) + assert 42 == john_info.age + assert "John" == john_info.name created_user.info = None created_user.update() john_info = UserModel.objects.first().info - self.assertIsNone(john_info) + assert john_info is None def test_can_create_same_udt_different_keyspaces(self): sync_type(DEFAULT_KEYSPACE, User) @@ -177,17 +178,17 @@ class UserModelGender(Model): UserModelGender.create(id=0, info=user) john_info = UserModelGender.objects.first().info - self.assertEqual(42, john_info.age) - self.assertEqual("John", john_info.name) - self.assertIsNone(john_info.gender) + assert 42 == john_info.age + assert "John" == john_info.name + assert john_info.gender is None user = UserGender(age=42) UserModelGender.create(id=0, info=user) john_info = UserModelGender.objects.first().info - self.assertEqual(42, john_info.age) - self.assertIsNone(john_info.name) - self.assertIsNone(john_info.gender) + assert 42 == john_info.age + assert john_info.name is None + assert john_info.gender is None def test_can_insert_nested_udts(self): class Depth_0(UserType): @@ -221,10 +222,10 @@ class DepthModel(Model): DepthModel.create(id=0, v_0=udts[0], v_1=udts[1], v_2=udts[2], v_3=udts[3]) output = DepthModel.objects.first() - self.assertEqual(udts[0], output.v_0) - self.assertEqual(udts[1], output.v_1) - self.assertEqual(udts[2], output.v_2) - self.assertEqual(udts[3], output.v_3) + assert udts[0] == output.v_0 + assert udts[1] == output.v_1 + assert udts[2] == output.v_2 + assert udts[3] == output.v_3 def test_can_insert_udts_with_nones(self): """ @@ -248,10 +249,10 @@ def test_can_insert_udts_with_nones(self): l=None, m=None, n=None) AllDatatypesModel.create(id=0, data=input) - self.assertEqual(1, AllDatatypesModel.objects.count()) + assert 1 == AllDatatypesModel.objects.count() output = AllDatatypesModel.objects.first().data - self.assertEqual(input, output) + assert input == output def 
test_can_insert_udts_with_all_datatypes(self): """ @@ -278,11 +279,11 @@ def test_can_insert_udts_with_all_datatypes(self): m=UUID('067e6162-3b6f-4ae2-a171-2470b63dff00'), n=int(str(2147483647) + '000')) AllDatatypesModel.create(id=0, data=input) - self.assertEqual(1, AllDatatypesModel.objects.count()) + assert 1 == AllDatatypesModel.objects.count() output = AllDatatypesModel.objects.first().data for i in range(ord('a'), ord('a') + 14): - self.assertEqual(input[chr(i)], output[chr(i)]) + assert input[chr(i)] == output[chr(i)] def test_can_insert_udts_protocol_v4_datatypes(self): """ @@ -320,11 +321,11 @@ class Allv4DatatypesModel(Model): input = Allv4Datatypes(a=Date(date(1970, 1, 1)), b=32523, c=Time(time(16, 47, 25, 7)), d=123) Allv4DatatypesModel.create(id=0, data=input) - self.assertEqual(1, Allv4DatatypesModel.objects.count()) + assert 1 == Allv4DatatypesModel.objects.count() output = Allv4DatatypesModel.objects.first().data for i in range(ord('a'), ord('a') + 3): - self.assertEqual(input[chr(i)], output[chr(i)]) + assert input[chr(i)] == output[chr(i)] def test_nested_udts_inserts(self): """ @@ -364,9 +365,9 @@ class Container(Model): Container.create(id=UUID('FE2B4360-28C6-11E2-81C1-0800200C9A66'), names=names) # Validate input and output matches - self.assertEqual(1, Container.objects.count()) + assert 1 == Container.objects.count() names_output = Container.objects.first().names - self.assertEqual(names_output, names) + assert names_output == names def test_udts_with_unicode(self): """ @@ -407,8 +408,8 @@ def test_register_default_keyspace(self): # None emulating no model and no default keyspace before connecting connection.udt_by_keyspace.clear() User.register_for_keyspace(None) - self.assertEqual(len(connection.udt_by_keyspace), 1) - self.assertIn(None, connection.udt_by_keyspace) + assert len(connection.udt_by_keyspace) == 1 + assert None in connection.udt_by_keyspace # register should be with default keyspace, not None cluster = Mock() @@ -443,9 +444,9 
@@ class TheModel(Model): type_fields = (db_field_different.age.column, db_field_different.name.column) - self.assertEqual(len(type_meta.field_names), len(type_fields)) + assert len(type_meta.field_names) == len(type_fields) for f in type_fields: - self.assertIn(f.db_field_name, type_meta.field_names) + assert f.db_field_name in type_meta.field_names id = 0 age = 42 @@ -453,17 +454,17 @@ class TheModel(Model): info = db_field_different(age=age, name=name) TheModel.create(id=id, info=info) - self.assertEqual(1, TheModel.objects.count()) + assert 1 == TheModel.objects.count() john = TheModel.objects.first() - self.assertEqual(john.id, id) + assert john.id == id info = john.info - self.assertIsInstance(info, db_field_different) - self.assertEqual(info.age, age) - self.assertEqual(info.name, name) + assert isinstance(info, db_field_different) + assert info.age == age + assert info.name == name # also excercise the db_Field mapping - self.assertEqual(info.a, age) - self.assertEqual(info.n, name) + assert info.a == age + assert info.n == name def test_db_field_overload(self): """ @@ -478,12 +479,12 @@ def test_db_field_overload(self): @test_category data_types:udt """ - with self.assertRaises(UserTypeDefinitionException): + with pytest.raises(UserTypeDefinitionException): class something_silly(UserType): first_col = columns.Integer() second_col = columns.Text(db_field='first_col') - with self.assertRaises(UserTypeDefinitionException): + with pytest.raises(UserTypeDefinitionException): class something_silly_2(UserType): first_col = columns.Integer(db_field="second_col") second_col = columns.Text() @@ -493,7 +494,7 @@ def test_set_udt_fields(self): u = User() u.age = 20 - self.assertEqual(20, u.age) + assert 20 == u.age def test_default_values(self): """ @@ -526,10 +527,10 @@ class OuterModel(Model): t.nested = [NestedUdt(something='test')] t.simple = NestedUdt(something="") t.save() - self.assertIsNotNone(t.nested[0].test_id) - self.assertEqual(t.nested[0].default_text, 
"default text") - self.assertIsNotNone(t.simple.test_id) - self.assertEqual(t.simple.default_text, "default text") + assert t.nested[0].test_id is not None + assert t.nested[0].default_text == "default text" + assert t.simple.test_id is not None + assert t.simple.default_text == "default text" def test_udt_validate(self): """ @@ -556,7 +557,7 @@ class UserModelValidate(Model): user = UserValidate(age=1, name="Robert") item = UserModelValidate(id=1, info=user) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): item.save() def test_udt_validate_with_default(self): @@ -584,5 +585,5 @@ class UserModelValidateDefault(Model): user = UserValidateDefault(age=1) item = UserModelValidateDefault(id=1, info=user) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): item.save() diff --git a/tests/integration/cqlengine/model/test_updates.py b/tests/integration/cqlengine/model/test_updates.py index 718c651880..c64df8fdcc 100644 --- a/tests/integration/cqlengine/model/test_updates.py +++ b/tests/integration/cqlengine/model/test_updates.py @@ -23,6 +23,7 @@ from cassandra.cqlengine import columns from cassandra.cqlengine.management import sync_table, drop_table from cassandra.cqlengine.usertype import UserType +import pytest class TestUpdateModel(Model): __test__ = False @@ -60,18 +61,18 @@ def test_update_model(self): # database should reflect both updates m2 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster) - self.assertEqual(m2.count, m1.count) - self.assertEqual(m2.text, m0.text) + assert m2.count == m1.count + assert m2.text == m0.text #This shouldn't raise a Validation error as the PR is not changing m0.update(partition=m0.partition, cluster=m0.cluster) #Assert a ValidationError is risen if the PR changes - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): m0.update(partition=m0.partition, cluster=20) # Assert a ValidationError is risen if the columns doesn't exist - with 
self.assertRaises(ValidationError): + with pytest.raises(ValidationError): m0.update(invalid_column=20) def test_update_values(self): @@ -85,12 +86,12 @@ def test_update_values(self): # update the text, and call update m0.update(text='monkey land') - self.assertEqual(m0.text, 'monkey land') + assert m0.text == 'monkey land' # database should reflect both updates m2 = TestUpdateModel.get(partition=m0.partition, cluster=m0.cluster) - self.assertEqual(m2.count, m1.count) - self.assertEqual(m2.text, m0.text) + assert m2.count == m1.count + assert m2.text == m0.text def test_noop_model_direct_update(self): """ Tests that calling update on a model with no changes will do nothing. """ @@ -139,13 +140,13 @@ def test_noop_model_assignation_update(self): def test_invalid_update_kwarg(self): """ tests that passing in a kwarg to the update method that isn't a column will fail """ m0 = TestUpdateModel.create(count=5, text='monkey') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): m0.update(numbers=20) def test_primary_key_update_failure(self): """ tests that attempting to update the value of a primary key will fail """ m0 = TestUpdateModel.create(count=5, text='monkey') - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): m0.update(partition=uuid4()) @@ -206,15 +207,13 @@ def test_value_override_with_default(self): initial = ModelWithDefault(id=1, mf={0: 0}, dummy=0, udt=first_udt, udt_default=first_udt) initial.save() - self.assertEqual(ModelWithDefault.get()._as_dict(), - {'id': 1, 'dummy': 0, 'mf': {0: 0}, "udt": first_udt, "udt_default": first_udt}) + assert ModelWithDefault.get()._as_dict() == {'id': 1, 'dummy': 0, 'mf': {0: 0}, "udt": first_udt, "udt_default": first_udt} second_udt = UDT(age=1, mf={3: 3}, dummy_udt=12) second = ModelWithDefault(id=1) second.update(mf={0: 1}, udt=second_udt) - self.assertEqual(ModelWithDefault.get()._as_dict(), - {'id': 1, 'dummy': 0, 'mf': {0: 1}, "udt": second_udt, 
"udt_default": first_udt}) + assert ModelWithDefault.get()._as_dict() == {'id': 1, 'dummy': 0, 'mf': {0: 1}, "udt": second_udt, "udt_default": first_udt} def test_value_is_written_if_is_default(self): """ @@ -231,8 +230,7 @@ def test_value_is_written_if_is_default(self): initial.udt_default = self.udt_default initial.update() - self.assertEqual(ModelWithDefault.get()._as_dict(), - {'id': 1, 'dummy': 42, 'mf': {0: 0}, "udt": None, "udt_default": self.udt_default}) + assert ModelWithDefault.get()._as_dict() == {'id': 1, 'dummy': 42, 'mf': {0: 0}, "udt": None, "udt_default": self.udt_default} def test_null_update_is_respected(self): """ @@ -253,8 +251,7 @@ def test_null_update_is_respected(self): updated_udt = UDT(age=1, mf={2:2}, dummy_udt=None) obj.update(dummy=None, udt_default=updated_udt) - self.assertEqual(ModelWithDefault.get()._as_dict(), - {'id': 1, 'dummy': None, 'mf': {0: 0}, "udt": None, "udt_default": updated_udt}) + assert ModelWithDefault.get()._as_dict() == {'id': 1, 'dummy': None, 'mf': {0: 0}, "udt": None, "udt_default": updated_udt} def test_only_set_values_is_updated(self): """ @@ -276,8 +273,7 @@ def test_only_set_values_is_updated(self): item.udt, item.udt_default = udt, udt_default item.save() - self.assertEqual(ModelWithDefault.get()._as_dict(), - {'id': 1, 'dummy': None, 'mf': {1: 2}, "udt": udt, "udt_default": udt_default}) + assert ModelWithDefault.get()._as_dict() == {'id': 1, 'dummy': None, 'mf': {1: 2}, "udt": udt, "udt_default": udt_default} def test_collections(self): """ @@ -296,8 +292,7 @@ def test_collections(self): udt, udt_default = UDT(age=1, mf={2: 1}), UDT(age=1, mf={2: 1}) item.update(mf={2:1}, udt=udt, udt_default=udt_default) - self.assertEqual(ModelWithDefault.get()._as_dict(), - {'id': 1, 'dummy': 1, 'mf': {2: 1}, "udt": udt, "udt_default": udt_default}) + assert ModelWithDefault.get()._as_dict() == {'id': 1, 'dummy': 1, 'mf': {2: 1}, "udt": udt, "udt_default": udt_default} def test_collection_with_default(self): """ @@ 
-314,38 +309,31 @@ def test_collection_with_default(self): udt, udt_default = UDT(age=1, mf={6: 6}), UDT(age=1, mf={6: 6}) item = ModelWithDefaultCollection.create(id=1, mf={1: 1}, dummy=1, udt=udt, udt_default=udt_default).save() - self.assertEqual(ModelWithDefaultCollection.objects.get(id=1)._as_dict(), - {'id': 1, 'dummy': 1, 'mf': {1: 1}, "udt": udt, "udt_default": udt_default}) + assert ModelWithDefaultCollection.objects.get(id=1)._as_dict() == {'id': 1, 'dummy': 1, 'mf': {1: 1}, "udt": udt, "udt_default": udt_default} udt, udt_default = UDT(age=1, mf={5: 5}), UDT(age=1, mf={5: 5}) item.update(mf={2: 2}, udt=udt, udt_default=udt_default) - self.assertEqual(ModelWithDefaultCollection.objects.get(id=1)._as_dict(), - {'id': 1, 'dummy': 1, 'mf': {2: 2}, "udt": udt, "udt_default": udt_default}) + assert ModelWithDefaultCollection.objects.get(id=1)._as_dict() == {'id': 1, 'dummy': 1, 'mf': {2: 2}, "udt": udt, "udt_default": udt_default} udt, udt_default = UDT(age=1, mf=None), UDT(age=1, mf=None) expected_udt, expected_udt_default = UDT(age=1, mf={}), UDT(age=1, mf={}) item.update(mf=None, udt=udt, udt_default=udt_default) - self.assertEqual(ModelWithDefaultCollection.objects.get(id=1)._as_dict(), - {'id': 1, 'dummy': 1, 'mf': {}, "udt": expected_udt, "udt_default": expected_udt_default}) + assert ModelWithDefaultCollection.objects.get(id=1)._as_dict() == {'id': 1, 'dummy': 1, 'mf': {}, "udt": expected_udt, "udt_default": expected_udt_default} udt_default = UDT(age=1, mf={2:2}, dummy_udt=42) item = ModelWithDefaultCollection.create(id=2, dummy=2) - self.assertEqual(ModelWithDefaultCollection.objects.get(id=2)._as_dict(), - {'id': 2, 'dummy': 2, 'mf': {2: 2}, "udt": None, "udt_default": udt_default}) + assert ModelWithDefaultCollection.objects.get(id=2)._as_dict() == {'id': 2, 'dummy': 2, 'mf': {2: 2}, "udt": None, "udt_default": udt_default} udt, udt_default = UDT(age=1, mf={1: 1, 6: 6}), UDT(age=1, mf={1: 1, 6: 6}) item.update(mf={1: 1, 4: 4}, udt=udt, 
udt_default=udt_default) - self.assertEqual(ModelWithDefaultCollection.objects.get(id=2)._as_dict(), - {'id': 2, 'dummy': 2, 'mf': {1: 1, 4: 4}, "udt": udt, "udt_default": udt_default}) + assert ModelWithDefaultCollection.objects.get(id=2)._as_dict() == {'id': 2, 'dummy': 2, 'mf': {1: 1, 4: 4}, "udt": udt, "udt_default": udt_default} item.update(udt_default=None) - self.assertEqual(ModelWithDefaultCollection.objects.get(id=2)._as_dict(), - {'id': 2, 'dummy': 2, 'mf': {1: 1, 4: 4}, "udt": udt, "udt_default": None}) + assert ModelWithDefaultCollection.objects.get(id=2)._as_dict() == {'id': 2, 'dummy': 2, 'mf': {1: 1, 4: 4}, "udt": udt, "udt_default": None} udt_default = UDT(age=1, mf={2:2}) item.update(udt_default=udt_default) - self.assertEqual(ModelWithDefaultCollection.objects.get(id=2)._as_dict(), - {'id': 2, 'dummy': 2, 'mf': {1: 1, 4: 4}, "udt": udt, "udt_default": udt_default}) + assert ModelWithDefaultCollection.objects.get(id=2)._as_dict() == {'id': 2, 'dummy': 2, 'mf': {1: 1, 4: 4}, "udt": udt, "udt_default": udt_default} def test_udt_to_python(self): @@ -370,5 +358,4 @@ def test_udt_to_python(self): item.update(udt=user_to_update) udt, udt_default = UDT(time_col=10), UDT(age=1, mf={2:2}) - self.assertEqual(ModelWithDefault.objects.get(id=1)._as_dict(), - {'id': 1, 'dummy': 42, 'mf': {}, "udt": udt, "udt_default": udt_default}) + assert ModelWithDefault.objects.get(id=1)._as_dict() == {'id': 1, 'dummy': 42, 'mf': {}, "udt": udt, "udt_default": udt_default} diff --git a/tests/integration/cqlengine/model/test_value_lists.py b/tests/integration/cqlengine/model/test_value_lists.py index 8fd7f4b392..cdab57ea38 100644 --- a/tests/integration/cqlengine/model/test_value_lists.py +++ b/tests/integration/cqlengine/model/test_value_lists.py @@ -57,7 +57,7 @@ def test_clustering_order(self): values = list(TestModel.objects.values_list('clustering_key', flat=True)) # [19L, 18L, 17L, 16L, 15L, 14L, 13L, 12L, 11L, 10L, 9L, 8L, 7L, 6L, 5L, 4L, 3L, 2L, 1L, 0L] - 
self.assertEqual(values, sorted(items, reverse=True)) + assert values == sorted(items, reverse=True) def test_clustering_order_more_complex(self): """ @@ -72,6 +72,6 @@ def test_clustering_order_more_complex(self): values = list(TestClusteringComplexModel.objects.values_list('some_value', flat=True)) - self.assertEqual([2] * 20, values) + assert [2] * 20 == values drop_table(TestClusteringComplexModel) diff --git a/tests/integration/cqlengine/operators/__init__.py b/tests/integration/cqlengine/operators/__init__.py index 05a41c46fd..45690e9448 100644 --- a/tests/integration/cqlengine/operators/__init__.py +++ b/tests/integration/cqlengine/operators/__init__.py @@ -15,6 +15,6 @@ from cassandra.cqlengine.operators import BaseWhereOperator -def check_lookup(test_case, symbol, expected): +def check_lookup(symbol, expected): op = BaseWhereOperator.get_operator(symbol) - test_case.assertEqual(op, expected) + assert op == expected diff --git a/tests/integration/cqlengine/operators/test_where_operators.py b/tests/integration/cqlengine/operators/test_where_operators.py index e04a377c88..c7e30b8905 100644 --- a/tests/integration/cqlengine/operators/test_where_operators.py +++ b/tests/integration/cqlengine/operators/test_where_operators.py @@ -26,6 +26,7 @@ from tests.integration.cqlengine.base import TestQueryUpdateModel, BaseCassEngTestCase from tests.integration.cqlengine.operators import check_lookup from tests.integration import greaterthanorequalcass30 +import pytest class TestWhereOperators(unittest.TestCase): @@ -33,27 +34,27 @@ class TestWhereOperators(unittest.TestCase): def test_symbol_lookup(self): """ tests where symbols are looked up properly """ - check_lookup(self, 'EQ', EqualsOperator) - check_lookup(self, 'NE', NotEqualsOperator) - check_lookup(self, 'IN', InOperator) - check_lookup(self, 'GT', GreaterThanOperator) - check_lookup(self, 'GTE', GreaterThanOrEqualOperator) - check_lookup(self, 'LT', LessThanOperator) - check_lookup(self, 'LTE', 
LessThanOrEqualOperator) - check_lookup(self, 'CONTAINS', ContainsOperator) - check_lookup(self, 'LIKE', LikeOperator) + check_lookup('EQ', EqualsOperator) + check_lookup('NE', NotEqualsOperator) + check_lookup('IN', InOperator) + check_lookup('GT', GreaterThanOperator) + check_lookup('GTE', GreaterThanOrEqualOperator) + check_lookup('LT', LessThanOperator) + check_lookup('LTE', LessThanOrEqualOperator) + check_lookup('CONTAINS', ContainsOperator) + check_lookup('LIKE', LikeOperator) def test_operator_rendering(self): """ tests symbols are rendered properly """ - self.assertEqual("=", str(EqualsOperator())) - self.assertEqual("!=", str(NotEqualsOperator())) - self.assertEqual("IN", str(InOperator())) - self.assertEqual(">", str(GreaterThanOperator())) - self.assertEqual(">=", str(GreaterThanOrEqualOperator())) - self.assertEqual("<", str(LessThanOperator())) - self.assertEqual("<=", str(LessThanOrEqualOperator())) - self.assertEqual("CONTAINS", str(ContainsOperator())) - self.assertEqual("LIKE", str(LikeOperator())) + assert "=" == str(EqualsOperator()) + assert "!=" == str(NotEqualsOperator()) + assert "IN" == str(InOperator()) + assert ">" == str(GreaterThanOperator()) + assert ">=" == str(GreaterThanOrEqualOperator()) + assert "<" == str(LessThanOperator()) + assert "<=" == str(LessThanOrEqualOperator()) + assert "CONTAINS" == str(ContainsOperator()) + assert "LIKE" == str(LikeOperator()) class TestIsNotNull(BaseCassEngTestCase): @@ -68,21 +69,15 @@ def test_is_not_null_to_cql(self): @test_category cqlengine """ - check_lookup(self, 'IS NOT NULL', IsNotNullOperator) + check_lookup('IS NOT NULL', IsNotNullOperator) # The * is not expanded because there are no referred fields - self.assertEqual( - str(TestQueryUpdateModel.filter(IsNotNull("text")).limit(2)), - 'SELECT * FROM cqlengine_test.test_query_update_model WHERE "text" IS NOT NULL LIMIT 2' - ) + assert str(TestQueryUpdateModel.filter(IsNotNull("text")).limit(2)) == 'SELECT * FROM 
cqlengine_test.test_query_update_model WHERE "text" IS NOT NULL LIMIT 2' # We already know partition so cqlengine doesn't query for it - self.assertEqual( - str(TestQueryUpdateModel.filter(IsNotNull("text"), partition=uuid4())), - ('SELECT "cluster", "count", "text", "text_set", ' - '"text_list", "text_map", "bin_map" FROM cqlengine_test.test_query_update_model ' - 'WHERE "text" IS NOT NULL AND "partition" = %(0)s LIMIT 10000') - ) + assert str(TestQueryUpdateModel.filter(IsNotNull("text"), partition=uuid4())) == ('SELECT "cluster", "count", "text", "text_set", ' + '"text_list", "text_map", "bin_map" FROM cqlengine_test.test_query_update_model ' + 'WHERE "text" IS NOT NULL AND "partition" = %(0)s LIMIT 10000') @greaterthanorequalcass30 def test_is_not_null_execution(self): @@ -102,8 +97,8 @@ def test_is_not_null_execution(self): self.addCleanup(drop_table, TestQueryUpdateModel) # Raises InvalidRequest instead of dse.protocol.SyntaxException - with self.assertRaises(InvalidRequest): + with pytest.raises(InvalidRequest): list(TestQueryUpdateModel.filter(IsNotNull("text"))) - with self.assertRaises(InvalidRequest): + with pytest.raises(InvalidRequest): list(TestQueryUpdateModel.filter(IsNotNull("text"), partition=uuid4())) diff --git a/tests/integration/cqlengine/query/test_batch_query.py b/tests/integration/cqlengine/query/test_batch_query.py index a69f19d811..512a580154 100644 --- a/tests/integration/cqlengine/query/test_batch_query.py +++ b/tests/integration/cqlengine/query/test_batch_query.py @@ -24,6 +24,7 @@ from cassandra.cluster import Session from cassandra.query import BatchType as cassandra_BatchType from cassandra.cqlengine.query import BatchType as cqlengine_BatchType +import pytest class TestMultiKeyModel(Model): @@ -70,7 +71,7 @@ def test_insert_success_case(self): b = BatchQuery() inst = TestMultiKeyModel.batch(b).create(partition=self.pkey, cluster=2, count=3, text='4') - with self.assertRaises(TestMultiKeyModel.DoesNotExist): + with 
pytest.raises(TestMultiKeyModel.DoesNotExist): TestMultiKeyModel.get(partition=self.pkey, cluster=2) b.execute() @@ -88,12 +89,12 @@ def test_update_success_case(self): inst.batch(b).save() inst2 = TestMultiKeyModel.get(partition=self.pkey, cluster=2) - self.assertEqual(inst2.count, 3) + assert inst2.count == 3 b.execute() inst3 = TestMultiKeyModel.get(partition=self.pkey, cluster=2) - self.assertEqual(inst3.count, 4) + assert inst3.count == 4 @execute_count(4) def test_delete_success_case(self): @@ -108,7 +109,7 @@ def test_delete_success_case(self): b.execute() - with self.assertRaises(TestMultiKeyModel.DoesNotExist): + with pytest.raises(TestMultiKeyModel.DoesNotExist): TestMultiKeyModel.get(partition=self.pkey, cluster=2) @execute_count(11) @@ -119,7 +120,7 @@ def test_context_manager(self): TestMultiKeyModel.batch(b).create(partition=self.pkey, cluster=i, count=3, text='4') for i in range(5): - with self.assertRaises(TestMultiKeyModel.DoesNotExist): + with pytest.raises(TestMultiKeyModel.DoesNotExist): TestMultiKeyModel.get(partition=self.pkey, cluster=i) for i in range(5): @@ -134,9 +135,9 @@ def test_bulk_delete_success_case(self): with BatchQuery() as b: TestMultiKeyModel.objects.batch(b).filter(partition=0).delete() - self.assertEqual(TestMultiKeyModel.filter(partition=0).count(), 5) + assert TestMultiKeyModel.filter(partition=0).count() == 5 - self.assertEqual(TestMultiKeyModel.filter(partition=0).count(), 0) + assert TestMultiKeyModel.filter(partition=0).count() == 0 #cleanup for m in TestMultiKeyModel.all(): m.delete() @@ -147,10 +148,10 @@ def test_none_success_case(self): b = BatchQuery() q = TestMultiKeyModel.objects.batch(b) - self.assertEqual(q._batch, b) + assert q._batch == b q = q.batch(None) - self.assertIsNone(q._batch) + assert q._batch is None @execute_count(0) def test_dml_none_success_case(self): @@ -158,10 +159,10 @@ def test_dml_none_success_case(self): b = BatchQuery() q = DMLQuery(TestMultiKeyModel, batch=b) - 
self.assertEqual(q._batch, b) + assert q._batch == b q.batch(None) - self.assertIsNone(q._batch) + assert q._batch is None @execute_count(3) def test_batch_execute_on_exception_succeeds(self): @@ -170,7 +171,7 @@ def test_batch_execute_on_exception_succeeds(self): sync_table(BatchQueryLogModel) obj = BatchQueryLogModel.objects(k=1) - self.assertEqual(0, len(obj)) + assert 0 == len(obj) try: with BatchQuery(execute_on_exception=True) as b: @@ -181,7 +182,7 @@ def test_batch_execute_on_exception_succeeds(self): obj = BatchQueryLogModel.objects(k=1) # should be 1 because the batch should execute - self.assertEqual(1, len(obj)) + assert 1 == len(obj) @execute_count(2) def test_batch_execute_on_exception_skips_if_not_specified(self): @@ -190,7 +191,7 @@ def test_batch_execute_on_exception_skips_if_not_specified(self): sync_table(BatchQueryLogModel) obj = BatchQueryLogModel.objects(k=2) - self.assertEqual(0, len(obj)) + assert 0 == len(obj) try: with BatchQuery() as b: @@ -202,21 +203,21 @@ def test_batch_execute_on_exception_skips_if_not_specified(self): obj = BatchQueryLogModel.objects(k=2) # should be 0 because the batch should not execute - self.assertEqual(0, len(obj)) + assert 0 == len(obj) @execute_count(1) def test_batch_execute_timeout(self): with mock.patch.object(Session, 'execute') as mock_execute: with BatchQuery(timeout=1) as b: BatchQueryLogModel.batch(b).create(k=2, v=2) - self.assertEqual(mock_execute.call_args[-1]['timeout'], 1) + assert mock_execute.call_args[-1]['timeout'] == 1 @execute_count(1) def test_batch_execute_no_timeout(self): with mock.patch.object(Session, 'execute') as mock_execute: with BatchQuery() as b: BatchQueryLogModel.batch(b).create(k=2, v=2) - self.assertEqual(mock_execute.call_args[-1]['timeout'], NOT_SET) + assert mock_execute.call_args[-1]['timeout'] == NOT_SET class BatchTypeQueryTests(BaseCassEngTestCase): @@ -245,7 +246,7 @@ def test_cassandra_batch_type(self): TestMultiKeyModel.batch(b).create(partition=1, cluster=2) obj = 
TestMultiKeyModel.objects(partition=1) - self.assertEqual(2, len(obj)) + assert 2 == len(obj) with BatchQuery(batch_type=cassandra_BatchType.COUNTER) as b: CounterBatchQueryModel.batch(b).create(k=1, v=1) @@ -253,15 +254,15 @@ def test_cassandra_batch_type(self): CounterBatchQueryModel.batch(b).create(k=1, v=10) obj = CounterBatchQueryModel.objects(k=1) - self.assertEqual(1, len(obj)) - self.assertEqual(obj[0].v, 13) + assert 1 == len(obj) + assert obj[0].v == 13 with BatchQuery(batch_type=cassandra_BatchType.LOGGED) as b: TestMultiKeyModel.batch(b).create(partition=1, cluster=1) TestMultiKeyModel.batch(b).create(partition=1, cluster=2) obj = TestMultiKeyModel.objects(partition=1) - self.assertEqual(2, len(obj)) + assert 2 == len(obj) @execute_count(4) def test_cqlengine_batch_type(self): @@ -280,7 +281,7 @@ def test_cqlengine_batch_type(self): TestMultiKeyModel.batch(b).create(partition=1, cluster=2) obj = TestMultiKeyModel.objects(partition=1) - self.assertEqual(2, len(obj)) + assert 2 == len(obj) with BatchQuery(batch_type=cqlengine_BatchType.Counter) as b: CounterBatchQueryModel.batch(b).create(k=1, v=1) @@ -288,5 +289,5 @@ def test_cqlengine_batch_type(self): CounterBatchQueryModel.batch(b).create(k=1, v=10) obj = CounterBatchQueryModel.objects(k=1) - self.assertEqual(1, len(obj)) - self.assertEqual(obj[0].v, 13) + assert 1 == len(obj) + assert obj[0].v == 13 diff --git a/tests/integration/cqlengine/query/test_datetime_queries.py b/tests/integration/cqlengine/query/test_datetime_queries.py index ba1c90bb9e..e61e1bdd96 100644 --- a/tests/integration/cqlengine/query/test_datetime_queries.py +++ b/tests/integration/cqlengine/query/test_datetime_queries.py @@ -15,6 +15,7 @@ from datetime import datetime, timedelta from uuid import uuid4 from cassandra.cqlengine.functions import get_total_seconds +import pytest from tests.integration.cqlengine.base import BaseCassEngTestCase @@ -70,6 +71,5 @@ def test_datetime_precision(self): obj = 
DateTimeQueryTestModel.create(user=pk, day=now, data='energy cheese') load = DateTimeQueryTestModel.get(user=pk) - self.assertAlmostEqual(get_total_seconds(now - load.day), 0, 2) + assert get_total_seconds(now - load.day) == pytest.approx(0, abs=1e-2) obj.delete() - diff --git a/tests/integration/cqlengine/query/test_named.py b/tests/integration/cqlengine/query/test_named.py index 0d5ba38200..24a6802b47 100644 --- a/tests/integration/cqlengine/query/test_named.py +++ b/tests/integration/cqlengine/query/test_named.py @@ -28,6 +28,7 @@ from tests.integration import BasicSharedKeyspaceUnitTestCase, greaterthanorequalcass30, requires_collection_indexes +import pytest class TestQuerySetOperation(BaseCassEngTestCase): @@ -77,46 +78,46 @@ def test_filter_method_where_clause_generation(self): Tests the where clause creation """ query1 = self.table.objects(test_id=5) - self.assertEqual(len(query1._where), 1) + assert len(query1._where) == 1 where = query1._where[0] - self.assertEqual(where.field, 'test_id') - self.assertEqual(where.value, 5) + assert where.field == 'test_id' + assert where.value == 5 query2 = query1.filter(expected_result__gte=1) - self.assertEqual(len(query2._where), 2) + assert len(query2._where) == 2 where = query2._where[0] - self.assertEqual(where.field, 'test_id') - self.assertIsInstance(where.operator, EqualsOperator) - self.assertEqual(where.value, 5) + assert where.field == 'test_id' + assert isinstance(where.operator, EqualsOperator) + assert where.value == 5 where = query2._where[1] - self.assertEqual(where.field, 'expected_result') - self.assertIsInstance(where.operator, GreaterThanOrEqualOperator) - self.assertEqual(where.value, 1) + assert where.field == 'expected_result' + assert isinstance(where.operator, GreaterThanOrEqualOperator) + assert where.value == 1 def test_query_expression_where_clause_generation(self): """ Tests the where clause creation """ query1 = self.table.objects(self.table.column('test_id') == 5) - 
self.assertEqual(len(query1._where), 1) + assert len(query1._where) == 1 where = query1._where[0] - self.assertEqual(where.field, 'test_id') - self.assertEqual(where.value, 5) + assert where.field == 'test_id' + assert where.value == 5 query2 = query1.filter(self.table.column('expected_result') >= 1) - self.assertEqual(len(query2._where), 2) + assert len(query2._where) == 2 where = query2._where[0] - self.assertEqual(where.field, 'test_id') - self.assertIsInstance(where.operator, EqualsOperator) - self.assertEqual(where.value, 5) + assert where.field == 'test_id' + assert isinstance(where.operator, EqualsOperator) + assert where.value == 5 where = query2._where[1] - self.assertEqual(where.field, 'expected_result') - self.assertIsInstance(where.operator, GreaterThanOrEqualOperator) - self.assertEqual(where.value, 1) + assert where.field == 'expected_result' + assert isinstance(where.operator, GreaterThanOrEqualOperator) + assert where.value == 1 @requires_collection_indexes class TestQuerySetCountSelectionAndIteration(BaseQuerySetUsage): @@ -265,7 +266,7 @@ def test_get_doesnotexist_exception(self): """ Tests that get calls that don't return a result raises a DoesNotExist error """ - with self.assertRaises(self.table.DoesNotExist): + with pytest.raises(self.table.DoesNotExist): self.table.objects.get(test_id=100) @execute_count(1) @@ -273,7 +274,7 @@ def test_get_multipleobjects_exception(self): """ Tests that get calls that return multiple results raise a MultipleObjectsReturned error """ - with self.assertRaises(self.table.MultipleObjectsReturned): + with pytest.raises(self.table.MultipleObjectsReturned): self.table.objects.get(test_id=1) @@ -358,17 +359,17 @@ def test_named_table_with_mv(self): key_space = NamedKeyspace(ks) mv_monthly = key_space.table("monthlyhigh") mv_all_time = key_space.table("alltimehigh") - self.assertTrue(self.check_table_size("scores", key_space, len(parameters))) - self.assertTrue(self.check_table_size("monthlyhigh", key_space, 
len(parameters))) - self.assertTrue(self.check_table_size("alltimehigh", key_space, len(parameters))) + assert self.check_table_size("scores", key_space, len(parameters)) + assert self.check_table_size("monthlyhigh", key_space, len(parameters)) + assert self.check_table_size("alltimehigh", key_space, len(parameters)) filtered_mv_monthly_objects = mv_monthly.objects.filter(game='Chess', year=2015, month=6) - self.assertEqual(len(filtered_mv_monthly_objects), 1) - self.assertEqual(filtered_mv_monthly_objects[0]['score'], 3500) - self.assertEqual(filtered_mv_monthly_objects[0]['user'], 'jbellis') + assert len(filtered_mv_monthly_objects) == 1 + assert filtered_mv_monthly_objects[0]['score'] == 3500 + assert filtered_mv_monthly_objects[0]['user'] == 'jbellis' filtered_mv_alltime_objects = mv_all_time.objects.filter(game='Chess') - self.assertEqual(len(filtered_mv_alltime_objects), 2) - self.assertEqual(filtered_mv_alltime_objects[0]['score'], 3500) + assert len(filtered_mv_alltime_objects) == 2 + assert filtered_mv_alltime_objects[0]['score'] == 3500 def check_table_size(self, table_name, key_space, expected_size): table = key_space.table(table_name) diff --git a/tests/integration/cqlengine/query/test_queryoperators.py b/tests/integration/cqlengine/query/test_queryoperators.py index fbf666cf21..b9e9356b06 100644 --- a/tests/integration/cqlengine/query/test_queryoperators.py +++ b/tests/integration/cqlengine/query/test_queryoperators.py @@ -25,6 +25,7 @@ from tests.integration.cqlengine import DEFAULT_KEYSPACE from tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration.cqlengine import execute_count +import pytest class TestQuerySetOperation(BaseCassEngTestCase): @@ -37,10 +38,10 @@ def test_maxtimeuuid_function(self): where = WhereClause('time', EqualsOperator(), functions.MaxTimeUUID(now)) where.set_context_id(5) - self.assertEqual(str(where), '"time" = MaxTimeUUID(%(5)s)') + assert str(where) == '"time" = MaxTimeUUID(%(5)s)' ctx = {} 
where.update_context(ctx) - self.assertEqual(ctx, {'5': columns.DateTime().to_database(now)}) + assert ctx == {'5': columns.DateTime().to_database(now)} def test_mintimeuuid_function(self): """ @@ -50,10 +51,10 @@ def test_mintimeuuid_function(self): where = WhereClause('time', EqualsOperator(), functions.MinTimeUUID(now)) where.set_context_id(5) - self.assertEqual(str(where), '"time" = MinTimeUUID(%(5)s)') + assert str(where) == '"time" = MinTimeUUID(%(5)s)' ctx = {} where.update_context(ctx) - self.assertEqual(ctx, {'5': columns.DateTime().to_database(now)}) + assert ctx == {'5': columns.DateTime().to_database(now)} class TokenTestModel(Model): @@ -93,7 +94,7 @@ def test_token_function(self): # pk__token equality r = TokenTestModel.objects(pk__token=functions.Token(last_token)) - self.assertEqual(len(r), 1) + assert len(r) == 1 r.all() # Attempt to obtain queryset for results. This has thrown an exception in the past def test_compound_pk_token_function(self): @@ -108,7 +109,7 @@ class TestModel(Model): q = TestModel.objects.filter(pk__token__gt=func) where = q._where[0] where.set_context_id(1) - self.assertEqual(str(where), 'token("p1", "p2") > token(%({0})s, %({1})s)'.format(1, 2)) + assert str(where) == 'token("p1", "p2") > token(%({0})s, %({1})s)'.format(1, 2) # Verify that a SELECT query can be successfully generated str(q._select_query()) @@ -120,19 +121,22 @@ class TestModel(Model): q = TestModel.objects.filter(pk__token__gt=func) where = q._where[0] where.set_context_id(1) - self.assertEqual(str(where), 'token("p1", "p2") > token(%({0})s, %({1})s)'.format(1, 2)) + assert str(where) == 'token("p1", "p2") > token(%({0})s, %({1})s)'.format(1, 2) str(q._select_query()) # The 'pk__token' virtual column may only be compared to a Token - self.assertRaises(query.QueryException, TestModel.objects.filter, pk__token__gt=10) + with pytest.raises(query.QueryException): + TestModel.objects.filter(pk__token__gt=10) # A Token may only be compared to the `pk__token' 
virtual column func = functions.Token('a', 'b') - self.assertRaises(query.QueryException, TestModel.objects.filter, p1__gt=func) + with pytest.raises(query.QueryException): + TestModel.objects.filter(p1__gt=func) # The # of arguments to Token must match the # of partition keys func = functions.Token('a') - self.assertRaises(query.QueryException, TestModel.objects.filter, pk__token__gt=func) + with pytest.raises(query.QueryException): + TestModel.objects.filter(pk__token__gt=func) @execute_count(7) def test_named_table_pk_token_function(self): @@ -154,6 +158,6 @@ def test_named_table_pk_token_function(self): query = named.all().limit(1) first_page = list(query) last = first_page[-1] - self.assertTrue(len(first_page) == 1) + assert len(first_page) == 1 next_page = list(query.filter(pk__token__gt=functions.Token(last.key))) - self.assertTrue(len(next_page) == 1) + assert len(next_page) == 1 diff --git a/tests/integration/cqlengine/query/test_queryset.py b/tests/integration/cqlengine/query/test_queryset.py index d15390827f..34b4ab5964 100644 --- a/tests/integration/cqlengine/query/test_queryset.py +++ b/tests/integration/cqlengine/query/test_queryset.py @@ -41,6 +41,7 @@ from tests.integration import PROTOCOL_VERSION, CASSANDRA_VERSION, greaterthancass20, greaterthancass21, \ greaterthanorequalcass30, TestCluster, requires_collection_indexes from tests.integration.cqlengine import execute_count, DEFAULT_KEYSPACE +import pytest class TzOffset(tzinfo): @@ -131,8 +132,8 @@ def test_query_filter_parsing(self): assert len(query2._where) == 2 op = query2._where[1] - self.assertIsInstance(op, statements.WhereClause) - self.assertIsInstance(op.operator, operators.GreaterThanOrEqualOperator) + assert isinstance(op, statements.WhereClause) + assert isinstance(op.operator, operators.GreaterThanOrEqualOperator) assert op.value == 1 def test_query_expression_parsing(self): @@ -149,29 +150,29 @@ def test_query_expression_parsing(self): assert len(query2._where) == 2 op = 
query2._where[1] - self.assertIsInstance(op, statements.WhereClause) - self.assertIsInstance(op.operator, operators.GreaterThanOrEqualOperator) + assert isinstance(op, statements.WhereClause) + assert isinstance(op.operator, operators.GreaterThanOrEqualOperator) assert op.value == 1 def test_using_invalid_column_names_in_filter_kwargs_raises_error(self): """ Tests that using invalid or nonexistant column names for filter args raises an error """ - with self.assertRaises(query.QueryException): + with pytest.raises(query.QueryException): TestModel.objects(nonsense=5) def test_using_nonexistant_column_names_in_query_args_raises_error(self): """ Tests that using invalid or nonexistant columns for query args raises an error """ - with self.assertRaises(AttributeError): + with pytest.raises(AttributeError): TestModel.objects(TestModel.nonsense == 5) def test_using_non_query_operators_in_query_args_raises_error(self): """ Tests that providing query args that are not query operator instances raises an error """ - with self.assertRaises(query.QueryException): + with pytest.raises(query.QueryException): TestModel.objects(5) def test_queryset_is_immutable(self): @@ -218,13 +219,13 @@ def test_queryset_with_distinct(self): """ query1 = TestModel.objects.distinct() - self.assertEqual(len(query1._distinct_fields), 1) + assert len(query1._distinct_fields) == 1 query2 = TestModel.objects.distinct(['test_id']) - self.assertEqual(len(query2._distinct_fields), 1) + assert len(query2._distinct_fields) == 1 query3 = TestModel.objects.distinct(['test_id', 'attempt_id']) - self.assertEqual(len(query3._distinct_fields), 2) + assert len(query3._distinct_fields) == 2 def test_defining_only_fields(self): """ @@ -238,35 +239,35 @@ def test_defining_only_fields(self): """ # simple only definition q = TestModel.objects.only(['attempt_id', 'description']) - self.assertEqual(q._select_fields(), ['attempt_id', 'description']) + assert q._select_fields() == ['attempt_id', 'description'] - with 
self.assertRaises(query.QueryException): + with pytest.raises(query.QueryException): TestModel.objects.only(['nonexistent_field']) # Cannot define more than once only fields - with self.assertRaises(query.QueryException): + with pytest.raises(query.QueryException): TestModel.objects.only(['description']).only(['attempt_id']) # only with defer fields q = TestModel.objects.only(['attempt_id', 'description']) q = q.defer(['description']) - self.assertEqual(q._select_fields(), ['attempt_id']) + assert q._select_fields() == ['attempt_id'] # Eliminate all results confirm exception is thrown q = TestModel.objects.only(['description']) q = q.defer(['description']) - with self.assertRaises(query.QueryException): + with pytest.raises(query.QueryException): q._select_fields() q = TestModel.objects.filter(test_id=0).only(['test_id', 'attempt_id', 'description']) - self.assertEqual(q._select_fields(), ['attempt_id', 'description']) + assert q._select_fields() == ['attempt_id', 'description'] # no fields to select - with self.assertRaises(query.QueryException): + with pytest.raises(query.QueryException): q = TestModel.objects.only(['test_id']).defer(['test_id']) q._select_fields() - with self.assertRaises(query.QueryException): + with pytest.raises(query.QueryException): q = TestModel.objects.filter(test_id=0).only(['test_id']) q._select_fields() @@ -284,34 +285,34 @@ def test_defining_defer_fields(self): # simple defer definition q = TestModel.objects.defer(['attempt_id', 'description']) - self.assertEqual(q._select_fields(), ['test_id', 'expected_result', 'test_result']) + assert q._select_fields() == ['test_id', 'expected_result', 'test_result'] - with self.assertRaises(query.QueryException): + with pytest.raises(query.QueryException): TestModel.objects.defer(['nonexistent_field']) # defer more than one q = TestModel.objects.defer(['attempt_id', 'description']) q = q.defer(['expected_result']) - self.assertEqual(q._select_fields(), ['test_id', 'test_result']) + assert 
q._select_fields() == ['test_id', 'test_result'] # defer with only q = TestModel.objects.defer(['description', 'attempt_id']) q = q.only(['description', 'test_id']) - self.assertEqual(q._select_fields(), ['test_id']) + assert q._select_fields() == ['test_id'] # Eliminate all results confirm exception is thrown q = TestModel.objects.defer(['description', 'attempt_id']) q = q.only(['description']) - with self.assertRaises(query.QueryException): + with pytest.raises(query.QueryException): q._select_fields() # implicit defer q = TestModel.objects.filter(test_id=0) - self.assertEqual(q._select_fields(), ['attempt_id', 'description', 'expected_result', 'test_result']) + assert q._select_fields() == ['attempt_id', 'description', 'expected_result', 'test_result'] # when all fields are defered, it fallbacks select the partition keys q = TestModel.objects.defer(['test_id', 'attempt_id', 'description', 'expected_result', 'test_result']) - self.assertEqual(q._select_fields(), ['test_id']) + assert q._select_fields() == ['test_id'] class BaseQuerySetUsage(BaseCassEngTestCase): @@ -523,7 +524,7 @@ def test_get_doesnotexist_exception(self): """ Tests that get calls that don't return a result raises a DoesNotExist error """ - with self.assertRaises(TestModel.DoesNotExist): + with pytest.raises(TestModel.DoesNotExist): TestModel.objects.get(test_id=100) @execute_count(1) @@ -531,7 +532,7 @@ def test_get_multipleobjects_exception(self): """ Tests that get calls that return multiple results raise a MultipleObjectsReturned error """ - with self.assertRaises(TestModel.MultipleObjectsReturned): + with pytest.raises(TestModel.MultipleObjectsReturned): TestModel.objects.get(test_id=1) def test_allow_filtering_flag(self): @@ -566,37 +567,37 @@ class TestQuerySetDistinct(BaseQuerySetUsage): @execute_count(1) def test_distinct_without_parameter(self): q = TestModel.objects.distinct() - self.assertEqual(len(q), 3) + assert len(q) == 3 @execute_count(1) def test_distinct_with_parameter(self): 
q = TestModel.objects.distinct(['test_id']) - self.assertEqual(len(q), 3) + assert len(q) == 3 @execute_count(1) def test_distinct_with_filter(self): q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1, 2]) - self.assertEqual(len(q), 2) + assert len(q) == 2 @execute_count(1) def test_distinct_with_non_partition(self): - with self.assertRaises(InvalidRequest): + with pytest.raises(InvalidRequest): q = TestModel.objects.distinct(['description']).filter(test_id__in=[1, 2]) len(q) @execute_count(1) def test_zero_result(self): q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[52]) - self.assertEqual(len(q), 0) + assert len(q) == 0 @greaterthancass21 @execute_count(2) def test_distinct_with_explicit_count(self): q = TestModel.objects.distinct(['test_id']) - self.assertEqual(q.count(), 3) + assert q.count() == 3 q = TestModel.objects.distinct(['test_id']).filter(test_id__in=[1, 2]) - self.assertEqual(q.count(), 2) + assert q.count() == 2 @requires_collection_indexes @@ -615,19 +616,19 @@ def test_order_by_success_case(self): def test_ordering_by_non_second_primary_keys_fail(self): # kwarg filtering - with self.assertRaises(query.QueryException): + with pytest.raises(query.QueryException): TestModel.objects(test_id=0).order_by('test_id') # kwarg filtering - with self.assertRaises(query.QueryException): + with pytest.raises(query.QueryException): TestModel.objects(TestModel.test_id == 0).order_by('test_id') def test_ordering_by_non_primary_keys_fails(self): - with self.assertRaises(query.QueryException): + with pytest.raises(query.QueryException): TestModel.objects(test_id=0).order_by('description') def test_ordering_on_indexed_columns_fails(self): - with self.assertRaises(query.QueryException): + with pytest.raises(query.QueryException): IndexedTestModel.objects(test_id=0).order_by('attempt_id') @execute_count(8) @@ -654,7 +655,7 @@ class TestQuerySetSlicing(BaseQuerySetUsage): @execute_count(1) def test_out_of_range_index_raises_error(self): q 
= TestModel.objects(test_id=0).order_by('attempt_id') - with self.assertRaises(IndexError): + with pytest.raises(IndexError): q[10] @execute_count(1) @@ -677,10 +678,10 @@ def test_slicing_works_properly(self): expected_order = [0, 1, 2, 3] for model, expect in zip(q[1:3], expected_order[1:3]): - self.assertEqual(model.attempt_id, expect) + assert model.attempt_id == expect for model, expect in zip(q[0:3:2], expected_order[0:3:2]): - self.assertEqual(model.attempt_id, expect) + assert model.attempt_id == expect @execute_count(1) def test_negative_slicing(self): @@ -688,19 +689,19 @@ def test_negative_slicing(self): expected_order = [0, 1, 2, 3] for model, expect in zip(q[-3:], expected_order[-3:]): - self.assertEqual(model.attempt_id, expect) + assert model.attempt_id == expect for model, expect in zip(q[:-1], expected_order[:-1]): - self.assertEqual(model.attempt_id, expect) + assert model.attempt_id == expect for model, expect in zip(q[1:-1], expected_order[1:-1]): - self.assertEqual(model.attempt_id, expect) + assert model.attempt_id == expect for model, expect in zip(q[-3:-1], expected_order[-3:-1]): - self.assertEqual(model.attempt_id, expect) + assert model.attempt_id == expect for model, expect in zip(q[-3:-1:2], expected_order[-3:-1:2]): - self.assertEqual(model.attempt_id, expect) + assert model.attempt_id == expect @requires_collection_indexes @@ -710,7 +711,7 @@ def test_primary_key_or_index_must_be_specified(self): """ Tests that queries that don't have an equals relation to a primary key or indexed field fail """ - with self.assertRaises(query.QueryException): + with pytest.raises(query.QueryException): q = TestModel.objects(test_result=25) list([i for i in q]) @@ -718,7 +719,7 @@ def test_primary_key_or_index_must_have_equal_relation_filter(self): """ Tests that queries that don't have non equal (>,<, etc) relation to a primary key or indexed field fail """ - with self.assertRaises(query.QueryException): + with pytest.raises(query.QueryException): q = 
TestModel.objects(test_id__gt=0) list([i for i in q]) @@ -729,52 +730,52 @@ def test_indexed_field_can_be_queried(self): Tests that queries on an indexed field will work without any primary key relations specified """ q = IndexedTestModel.objects(test_result=25) - self.assertEqual(q.count(), 4) + assert q.count() == 4 q = IndexedCollectionsTestModel.objects.filter(test_list__contains=42) - self.assertEqual(q.count(), 1) + assert q.count() == 1 q = IndexedCollectionsTestModel.objects.filter(test_list__contains=13) - self.assertEqual(q.count(), 0) + assert q.count() == 0 q = IndexedCollectionsTestModel.objects.filter(test_set__contains=42) - self.assertEqual(q.count(), 1) + assert q.count() == 1 q = IndexedCollectionsTestModel.objects.filter(test_set__contains=13) - self.assertEqual(q.count(), 0) + assert q.count() == 0 q = IndexedCollectionsTestModel.objects.filter(test_map__contains=42) - self.assertEqual(q.count(), 1) + assert q.count() == 1 q = IndexedCollectionsTestModel.objects.filter(test_map__contains=13) - self.assertEqual(q.count(), 0) + assert q.count() == 0 def test_custom_indexed_field_can_be_queried(self): """ Tests that queries on an custom indexed field will work without any primary key relations specified """ - with self.assertRaises(query.QueryException): + with pytest.raises(query.QueryException): list(CustomIndexedTestModel.objects.filter(data='test')) # not custom indexed # It should return InvalidRequest if target an indexed columns - with self.assertRaises(InvalidRequest): + with pytest.raises(InvalidRequest): list(CustomIndexedTestModel.objects.filter(indexed='test', data='test')) # It should return InvalidRequest if target an indexed columns - with self.assertRaises(InvalidRequest): + with pytest.raises(InvalidRequest): list(CustomIndexedTestModel.objects.filter(description='test', data='test')) # equals operator, server error since there is no real index, but it passes - with self.assertRaises(InvalidRequest): + with 
pytest.raises(InvalidRequest): list(CustomIndexedTestModel.objects.filter(description='test')) - with self.assertRaises(InvalidRequest): + with pytest.raises(InvalidRequest): list(CustomIndexedTestModel.objects.filter(test_id=1, description='test')) # gte operator, server error since there is no real index, but it passes # this can't work with a secondary index - with self.assertRaises(InvalidRequest): + with pytest.raises(InvalidRequest): list(CustomIndexedTestModel.objects.filter(description__gte='test')) with TestCluster().connect() as session: @@ -805,12 +806,12 @@ def test_delete(self): def test_delete_without_partition_key(self): """ Tests that attempting to delete a model without defining a partition key fails """ - with self.assertRaises(query.QueryException): + with pytest.raises(query.QueryException): TestModel.objects(attempt_id=0).delete() def test_delete_without_any_where_args(self): """ Tests that attempting to delete a whole table without any arguments will fail """ - with self.assertRaises(query.QueryException): + with pytest.raises(query.QueryException): TestModel.objects(attempt_id=0).delete() @greaterthanorequalcass30 @@ -824,16 +825,16 @@ def test_range_deletion(self): TestMultiClusteringModel.objects().create(one=1, two=i, three=i) TestMultiClusteringModel.objects(one=1, two__gte=0, two__lte=3).delete() - self.assertEqual(6, len(TestMultiClusteringModel.objects.all())) + assert 6 == len(TestMultiClusteringModel.objects.all()) TestMultiClusteringModel.objects(one=1, two__gt=3, two__lt=5).delete() - self.assertEqual(5, len(TestMultiClusteringModel.objects.all())) + assert 5 == len(TestMultiClusteringModel.objects.all()) TestMultiClusteringModel.objects(one=1, two__in=[8, 9]).delete() - self.assertEqual(3, len(TestMultiClusteringModel.objects.all())) + assert 3 == len(TestMultiClusteringModel.objects.all()) TestMultiClusteringModel.objects(one__in=[1], two__gte=0).delete() - self.assertEqual(0, len(TestMultiClusteringModel.objects.all())) + assert 
0 == len(TestMultiClusteringModel.objects.all()) class TimeUUIDQueryModel(Model): @@ -912,7 +913,7 @@ def test_success_case(self): # test kwarg filtering q = TimeUUIDQueryModel.filter(partition=pk, time__lte=functions.MaxTimeUUID(midpoint)) q = [d for d in q] - self.assertEqual(len(q), 2, msg="Got: %s" % q) + assert len(q) == 2, "Got: %s" % q datas = [d.data for d in q] assert '1' in datas assert '2' in datas @@ -977,9 +978,9 @@ class bool_model(Model): bool_model.create(k=0, b=True) bool_model.create(k=0, b=False) - self.assertEqual(len(bool_model.objects.all()), 2) - self.assertEqual(len(bool_model.objects.filter(k=0, b=True)), 1) - self.assertEqual(len(bool_model.objects.filter(k=0, b=False)), 1) + assert len(bool_model.objects.all()) == 2 + assert len(bool_model.objects.filter(k=0, b=True)) == 1 + assert len(bool_model.objects.filter(k=0, b=False)) == 1 @execute_count(3) def test_bool_filter(self): @@ -1001,7 +1002,7 @@ class bool_model2(Model): bool_model2.create(k=True, b=1, v='a') bool_model2.create(k=False, b=1, v='b') - self.assertEqual(len(list(bool_model2.objects(k__in=(True, False)))), 2) + assert len(list(bool_model2.objects(k__in=(True, False)))) == 2 @greaterthancass20 @@ -1012,63 +1013,63 @@ class TestContainsOperator(BaseQuerySetUsage): def test_kwarg_success_case(self): """ Tests the CONTAINS operator works with the kwarg query method """ q = IndexedCollectionsTestModel.filter(test_list__contains=1) - self.assertEqual(q.count(), 2) + assert q.count() == 2 q = IndexedCollectionsTestModel.filter(test_list__contains=13) - self.assertEqual(q.count(), 0) + assert q.count() == 0 q = IndexedCollectionsTestModel.filter(test_set__contains=3) - self.assertEqual(q.count(), 2) + assert q.count() == 2 q = IndexedCollectionsTestModel.filter(test_set__contains=13) - self.assertEqual(q.count(), 0) + assert q.count() == 0 q = IndexedCollectionsTestModel.filter(test_map__contains=42) - self.assertEqual(q.count(), 1) + assert q.count() == 1 q = 
IndexedCollectionsTestModel.filter(test_map__contains=13) - self.assertEqual(q.count(), 0) + assert q.count() == 0 - with self.assertRaises(QueryException): + with pytest.raises(QueryException): q = IndexedCollectionsTestModel.filter(test_list_no_index__contains=1) - self.assertEqual(q.count(), 0) - with self.assertRaises(QueryException): + assert q.count() == 0 + with pytest.raises(QueryException): q = IndexedCollectionsTestModel.filter(test_set_no_index__contains=1) - self.assertEqual(q.count(), 0) - with self.assertRaises(QueryException): + assert q.count() == 0 + with pytest.raises(QueryException): q = IndexedCollectionsTestModel.filter(test_map_no_index__contains=1) - self.assertEqual(q.count(), 0) + assert q.count() == 0 @execute_count(6) def test_query_expression_success_case(self): """ Tests the CONTAINS operator works with the query expression query method """ q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_list.contains_(1)) - self.assertEqual(q.count(), 2) + assert q.count() == 2 q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_list.contains_(13)) - self.assertEqual(q.count(), 0) + assert q.count() == 0 q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_set.contains_(3)) - self.assertEqual(q.count(), 2) + assert q.count() == 2 q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_set.contains_(13)) - self.assertEqual(q.count(), 0) + assert q.count() == 0 q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_map.contains_(42)) - self.assertEqual(q.count(), 1) + assert q.count() == 1 q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_map.contains_(13)) - self.assertEqual(q.count(), 0) + assert q.count() == 0 - with self.assertRaises(QueryException): + with pytest.raises(QueryException): q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_map_no_index.contains_(1)) - self.assertEqual(q.count(), 0) - with 
self.assertRaises(QueryException): + assert q.count() == 0 + with pytest.raises(QueryException): q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_map_no_index.contains_(1)) - self.assertEqual(q.count(), 0) - with self.assertRaises(QueryException): + assert q.count() == 0 + with pytest.raises(QueryException): q = IndexedCollectionsTestModel.filter(IndexedCollectionsTestModel.test_map_no_index.contains_(1)) - self.assertEqual(q.count(), 0) + assert q.count() == 0 @requires_collection_indexes @@ -1120,17 +1121,17 @@ class ModelQuerySetTimeoutTestCase(BaseQuerySetUsage): def test_default_timeout(self): with mock.patch.object(Session, 'execute') as mock_execute: list(TestModel.objects()) - self.assertEqual(mock_execute.call_args[-1]['timeout'], NOT_SET) + assert mock_execute.call_args[-1]['timeout'] == NOT_SET def test_float_timeout(self): with mock.patch.object(Session, 'execute') as mock_execute: list(TestModel.objects().timeout(0.5)) - self.assertEqual(mock_execute.call_args[-1]['timeout'], 0.5) + assert mock_execute.call_args[-1]['timeout'] == 0.5 def test_none_timeout(self): with mock.patch.object(Session, 'execute') as mock_execute: list(TestModel.objects().timeout(None)) - self.assertEqual(mock_execute.call_args[-1]['timeout'], None) + assert mock_execute.call_args[-1]['timeout'] is None @requires_collection_indexes @@ -1142,28 +1143,28 @@ def setUp(self): def test_default_timeout(self): with mock.patch.object(Session, 'execute') as mock_execute: self.model.save() - self.assertEqual(mock_execute.call_args[-1]['timeout'], NOT_SET) + assert mock_execute.call_args[-1]['timeout'] == NOT_SET def test_float_timeout(self): with mock.patch.object(Session, 'execute') as mock_execute: self.model.timeout(0.5).save() - self.assertEqual(mock_execute.call_args[-1]['timeout'], 0.5) + assert mock_execute.call_args[-1]['timeout'] == 0.5 def test_none_timeout(self): with mock.patch.object(Session, 'execute') as mock_execute: self.model.timeout(None).save() -
self.assertEqual(mock_execute.call_args[-1]['timeout'], None) + assert mock_execute.call_args[-1]['timeout'] is None def test_timeout_then_batch(self): b = query.BatchQuery() m = self.model.timeout(None) - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): m.batch(b) def test_batch_then_timeout(self): b = query.BatchQuery() m = self.model.batch(b) - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): m.timeout(0.5) @@ -1220,26 +1221,26 @@ def test_basic_crud(self): # create i = model.create(**values) i = model.objects(k0=i.k0, k1=i.k1).first() - self.assertEqual(i, model(**values)) + assert i == model(**values) # create values['v0'] = 101 i.update(v0=values['v0']) i = model.objects(k0=i.k0, k1=i.k1).first() - self.assertEqual(i, model(**values)) + assert i == model(**values) # delete model.objects(k0=i.k0, k1=i.k1).delete() i = model.objects(k0=i.k0, k1=i.k1).first() - self.assertIsNone(i) + assert i is None i = model.create(**values) i = model.objects(k0=i.k0, k1=i.k1).first() - self.assertEqual(i, model(**values)) + assert i == model(**values) i.delete() model.objects(k0=i.k0, k1=i.k1).delete() i = model.objects(k0=i.k0, k1=i.k1).first() - self.assertIsNone(i) + assert i is None @execute_count(21) def test_slice(self): @@ -1259,10 +1260,10 @@ def test_slice(self): values['c0'] = c i = model.create(**values) - self.assertEqual(model.objects(k0=i.k0, k1=i.k1).count(), len(clustering_values)) - self.assertEqual(model.objects(k0=i.k0, k1=i.k1, c0=i.c0).count(), 1) - self.assertEqual(model.objects(k0=i.k0, k1=i.k1, c0__lt=i.c0).count(), len(clustering_values[:-1])) - self.assertEqual(model.objects(k0=i.k0, k1=i.k1, c0__gt=0).count(), len(clustering_values[1:])) + assert model.objects(k0=i.k0, k1=i.k1).count() == len(clustering_values) + assert model.objects(k0=i.k0, k1=i.k1, c0=i.c0).count() == 1 + assert model.objects(k0=i.k0, k1=i.k1, c0__lt=i.c0).count() == len(clustering_values[:-1]) + assert
model.objects(k0=i.k0, k1=i.k1, c0__gt=0).count() == len(clustering_values[1:]) @execute_count(15) def test_order(self): @@ -1281,8 +1282,8 @@ def test_order(self): for c in clustering_values: values['c0'] = c i = model.create(**values) - self.assertEqual(model.objects(k0=i.k0, k1=i.k1).order_by('c0').first().c0, clustering_values[0]) - self.assertEqual(model.objects(k0=i.k0, k1=i.k1).order_by('-c0').first().c0, clustering_values[-1]) + assert model.objects(k0=i.k0, k1=i.k1).order_by('c0').first().c0 == clustering_values[0] + assert model.objects(k0=i.k0, k1=i.k1).order_by('-c0').first().c0 == clustering_values[-1] @execute_count(15) def test_index(self): @@ -1302,8 +1303,8 @@ def test_index(self): values['c0'] = c values['v1'] = c i = model.create(**values) - self.assertEqual(model.objects(k0=i.k0, k1=i.k1).count(), len(clustering_values)) - self.assertEqual(model.objects(k0=i.k0, k1=i.k1, v1=0).count(), 1) + assert model.objects(k0=i.k0, k1=i.k1).count() == len(clustering_values) + assert model.objects(k0=i.k0, k1=i.k1, v1=0).count() == 1 @execute_count(1) def test_db_field_names_used(self): @@ -1325,7 +1326,7 @@ def test_db_field_names_used(self): v1=9, ) for value in values: - self.assertTrue(value not in str(b.queries[0])) + assert value not in str(b.queries[0]) # Test DML path b2 = BatchQuery() @@ -1335,15 +1336,13 @@ def test_db_field_names_used(self): v1=9, ) for value in values: - self.assertTrue(value not in str(b2.queries[0])) + assert value not in str(b2.queries[0]) def test_db_field_value_list(self): DBFieldModel.create(k0=0, k1=0, c0=0, v0=4, v1=5) - self.assertEqual(DBFieldModel.objects.filter(c0=0, k0=0, k1=0).values_list('c0', 'v0')._defer_fields, - {'a', 'c', 'b'}) - self.assertEqual(DBFieldModel.objects.filter(c0=0, k0=0, k1=0).values_list('c0', 'v0')._only_fields, - ['c', 'd']) + assert DBFieldModel.objects.filter(c0=0, k0=0, k1=0).values_list('c0', 'v0')._defer_fields == {'a', 'c', 'b'} + assert DBFieldModel.objects.filter(c0=0, k0=0, 
k1=0).values_list('c0', 'v0')._only_fields == ['c', 'd'] list(DBFieldModel.objects.filter(c0=0, k0=0, k1=0).values_list('c0', 'v0')) @@ -1390,18 +1389,18 @@ def test_defaultFetchSize(self): for i in range(5000, 5100): TestModelSmall.batch(b).create(test_id=i) - self.assertEqual(len(TestModelSmall.objects.fetch_size(1)), 5100) - self.assertEqual(len(TestModelSmall.objects.fetch_size(500)), 5100) - self.assertEqual(len(TestModelSmall.objects.fetch_size(4999)), 5100) - self.assertEqual(len(TestModelSmall.objects.fetch_size(5000)), 5100) - self.assertEqual(len(TestModelSmall.objects.fetch_size(5001)), 5100) - self.assertEqual(len(TestModelSmall.objects.fetch_size(5100)), 5100) - self.assertEqual(len(TestModelSmall.objects.fetch_size(5101)), 5100) - self.assertEqual(len(TestModelSmall.objects.fetch_size(1)), 5100) + assert len(TestModelSmall.objects.fetch_size(1)) == 5100 + assert len(TestModelSmall.objects.fetch_size(500)) == 5100 + assert len(TestModelSmall.objects.fetch_size(4999)) == 5100 + assert len(TestModelSmall.objects.fetch_size(5000)) == 5100 + assert len(TestModelSmall.objects.fetch_size(5001)) == 5100 + assert len(TestModelSmall.objects.fetch_size(5100)) == 5100 + assert len(TestModelSmall.objects.fetch_size(5101)) == 5100 + assert len(TestModelSmall.objects.fetch_size(1)) == 5100 - with self.assertRaises(QueryException): + with pytest.raises(QueryException): TestModelSmall.objects.fetch_size(0) - with self.assertRaises(QueryException): + with pytest.raises(QueryException): TestModelSmall.objects.fetch_size(-1) @@ -1453,11 +1452,11 @@ def test_defaultFetchSize(self): # Check query constructions expected_fields = ['first_name', 'birthday'] - self.assertEqual(People.filter(last_name="Smith")._select_fields(), expected_fields) + assert People.filter(last_name="Smith")._select_fields() == expected_fields # Validate correct fields are fetched smiths = list(People.filter(last_name="Smith")) - self.assertEqual(len(smiths), 3) - self.assertTrue(smiths[0].last_name 
is not None) + assert len(smiths) == 3 + assert smiths[0].last_name is not None # Modify table with new value sync_table(People2) @@ -1468,9 +1467,9 @@ def test_defaultFetchSize(self): # validate query construction expected_fields = ['first_name', 'middle_name', 'birthday'] - self.assertEqual(People2.filter(last_name="Smith")._select_fields(), expected_fields) + assert People2.filter(last_name="Smith")._select_fields() == expected_fields # validate correct items are returneds smiths = list(People2.filter(last_name="Smith")) - self.assertEqual(len(smiths), 5) - self.assertTrue(smiths[0].last_name is not None) + assert len(smiths) == 5 + assert smiths[0].last_name is not None diff --git a/tests/integration/cqlengine/query/test_updates.py b/tests/integration/cqlengine/query/test_updates.py index f92e4fc53f..cedde0cd7b 100644 --- a/tests/integration/cqlengine/query/test_updates.py +++ b/tests/integration/cqlengine/query/test_updates.py @@ -23,6 +23,7 @@ from tests.integration.cqlengine.base import BaseCassEngTestCase, TestQueryUpdateModel from tests.integration.cqlengine import execute_count from tests.integration import greaterthancass20 +import pytest class QueryUpdateTests(BaseCassEngTestCase): @@ -46,17 +47,17 @@ def test_update_values(self): # sanity check for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): - self.assertEqual(row.cluster, i) - self.assertEqual(row.count, i) - self.assertEqual(row.text, str(i)) + assert row.cluster == i + assert row.count == i + assert row.text == str(i) # perform update TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6) for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): - self.assertEqual(row.cluster, i) - self.assertEqual(row.count, 6 if i == 3 else i) - self.assertEqual(row.text, str(i)) + assert row.cluster == i + assert row.count == (6 if i == 3 else i) + assert row.text == str(i) @execute_count(6) def test_update_values_validation(self): @@ -67,22 
+68,22 @@ def test_update_values_validation(self): # sanity check for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): - self.assertEqual(row.cluster, i) - self.assertEqual(row.count, i) - self.assertEqual(row.text, str(i)) + assert row.cluster == i + assert row.count == i + assert row.text == str(i) # perform update - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count='asdf') def test_invalid_update_kwarg(self): """ tests that passing in a kwarg to the update method that isn't a column will fail """ - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): TestQueryUpdateModel.objects(partition=uuid4(), cluster=3).update(bacon=5000) def test_primary_key_update_failure(self): """ tests that attempting to update the value of a primary key will fail """ - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): TestQueryUpdateModel.objects(partition=uuid4(), cluster=3).update(cluster=5000) @execute_count(8) @@ -94,17 +95,17 @@ def test_null_update_deletes_column(self): # sanity check for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): - self.assertEqual(row.cluster, i) - self.assertEqual(row.count, i) - self.assertEqual(row.text, str(i)) + assert row.cluster == i + assert row.count == i + assert row.text == str(i) # perform update TestQueryUpdateModel.objects(partition=partition, cluster=3).update(text=None) for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): - self.assertEqual(row.cluster, i) - self.assertEqual(row.count, i) - self.assertEqual(row.text, None if i == 3 else str(i)) + assert row.cluster == i + assert row.count == i + assert row.text == (None if i == 3 else str(i)) @execute_count(9) def test_mixed_value_and_null_update(self): @@ -115,17 +116,17 @@ def test_mixed_value_and_null_update(self): # sanity check for i, row in 
enumerate(TestQueryUpdateModel.objects(partition=partition)): - self.assertEqual(row.cluster, i) - self.assertEqual(row.count, i) - self.assertEqual(row.text, str(i)) + assert row.cluster == i + assert row.count == i + assert row.text == str(i) # perform update TestQueryUpdateModel.objects(partition=partition, cluster=3).update(count=6, text=None) for i, row in enumerate(TestQueryUpdateModel.objects(partition=partition)): - self.assertEqual(row.cluster, i) - self.assertEqual(row.count, 6 if i == 3 else i) - self.assertEqual(row.text, None if i == 3 else str(i)) + assert row.cluster == i + assert row.count == (6 if i == 3 else i) + assert row.text == (None if i == 3 else str(i)) @execute_count(3) def test_set_add_updates(self): @@ -136,7 +137,7 @@ def test_set_add_updates(self): TestQueryUpdateModel.objects( partition=partition, cluster=cluster).update(text_set__add=set(('bar',))) obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) - self.assertEqual(obj.text_set, set(("foo", "bar"))) + assert obj.text_set == set(("foo", "bar")) @execute_count(2) def test_set_add_updates_new_record(self): @@ -147,7 +148,7 @@ def test_set_add_updates_new_record(self): TestQueryUpdateModel.objects( partition=partition, cluster=cluster).update(text_set__add=set(('bar',))) obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) - self.assertEqual(obj.text_set, set(("bar",))) + assert obj.text_set == set(("bar",)) @execute_count(3) def test_set_remove_updates(self): @@ -159,7 +160,7 @@ def test_set_remove_updates(self): partition=partition, cluster=cluster).update( text_set__remove=set(('foo',))) obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) - self.assertEqual(obj.text_set, set(("baz",))) + assert obj.text_set == set(("baz",)) @execute_count(3) def test_set_remove_new_record(self): @@ -173,7 +174,7 @@ def test_set_remove_new_record(self): partition=partition, cluster=cluster).update( 
text_set__remove=set(('afsd',))) obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) - self.assertEqual(obj.text_set, set(("foo",))) + assert obj.text_set == set(("foo",)) @execute_count(3) def test_list_append_updates(self): @@ -185,7 +186,7 @@ def test_list_append_updates(self): partition=partition, cluster=cluster).update( text_list__append=['bar']) obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) - self.assertEqual(obj.text_list, ["foo", "bar"]) + assert obj.text_list == ["foo", "bar"] @execute_count(3) def test_list_prepend_updates(self): @@ -201,7 +202,7 @@ def test_list_prepend_updates(self): text_list__prepend=prepended) obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) expected = (prepended[::-1] if is_prepend_reversed() else prepended) + original - self.assertEqual(obj.text_list, expected) + assert obj.text_list == expected @execute_count(3) def test_map_update_updates(self): @@ -215,7 +216,7 @@ def test_map_update_updates(self): partition=partition, cluster=cluster).update( text_map__update={"bar": '3', "baz": '4'}) obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) - self.assertEqual(obj.text_map, {"foo": '1', "bar": '3', "baz": '4'}) + assert obj.text_map == {"foo": '1', "bar": '3', "baz": '4'} @execute_count(3) def test_map_update_none_deletes_key(self): @@ -231,7 +232,7 @@ def test_map_update_none_deletes_key(self): partition=partition, cluster=cluster).update( text_map__update={"bar": None}) obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) - self.assertEqual(obj.text_map, {"foo": '1'}) + assert obj.text_map == {"foo": '1'} @greaterthancass20 @execute_count(5) @@ -256,22 +257,16 @@ def test_map_update_remove(self): bin_map__update={456: b'4', 123: b'2'} ) obj = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) - self.assertEqual(obj.text_map, {"foo": '2', "foz": '4'}) - 
self.assertEqual(obj.bin_map, {123: b'2', 456: b'4'}) + assert obj.text_map == {"foo": '2', "foz": '4'} + assert obj.bin_map == {123: b'2', 456: b'4'} TestQueryUpdateModel.objects(partition=partition, cluster=cluster).update( text_map__remove={"foo", "foz"}, bin_map__remove={123, 456} ) rec = TestQueryUpdateModel.objects.get(partition=partition, cluster=cluster) - self.assertEqual( - rec.text_map, - {} - ) - self.assertEqual( - rec.bin_map, - {} - ) + assert rec.text_map == {} + assert rec.bin_map == {} def test_map_remove_rejects_non_sets(self): """ @@ -286,7 +281,7 @@ def test_map_remove_rejects_non_sets(self): cluster=cluster, text_map={"foo": '1', "bar": '2'} ) - with self.assertRaises(ValidationError): + with pytest.raises(ValidationError): TestQueryUpdateModel.objects(partition=partition, cluster=cluster).update( text_map__remove=["bar"] ) @@ -314,7 +309,7 @@ def test_an_extra_delete_is_not_sent(self): obj = TestQueryUpdateModel.objects( partition=partition, cluster=cluster).first() - self.assertFalse({k: v for (k, v) in obj._values.items() if v.deleted}) + assert not {k: v for (k, v) in obj._values.items() if v.deleted} obj.text = 'foo' obj.save() @@ -352,6 +347,6 @@ def test_static_deletion(self): """ StaticDeleteModel.create(example_id=5, example_clust=5, example_static2=1) sdm = StaticDeleteModel.filter(example_id=5).first() - self.assertEqual(1, sdm.example_static2) + assert 1 == sdm.example_static2 sdm.update(example_static2=None) - self.assertIsNone(sdm.example_static2) + assert sdm.example_static2 is None diff --git a/tests/integration/cqlengine/statements/test_assignment_clauses.py b/tests/integration/cqlengine/statements/test_assignment_clauses.py index 82bf067cb4..dce910fd5e 100644 --- a/tests/integration/cqlengine/statements/test_assignment_clauses.py +++ b/tests/integration/cqlengine/statements/test_assignment_clauses.py @@ -24,7 +24,7 @@ def test_rendering(self): def test_insert_tuple(self): ac = AssignmentClause('a', 'b') ac.set_context_id(10) 
- self.assertEqual(ac.insert_tuple(), ('a', 10)) + assert ac.insert_tuple() == ('a', 10) class SetUpdateClauseTests(unittest.TestCase): @@ -34,16 +34,16 @@ def test_update_from_none(self): c._analyze() c.set_context_id(0) - self.assertEqual(c._assignments, set((1, 2))) - self.assertIsNone(c._additions) - self.assertIsNone(c._removals) + assert c._assignments == set((1, 2)) + assert c._additions is None + assert c._removals is None - self.assertEqual(c.get_context_size(), 1) - self.assertEqual(str(c), '"s" = %(0)s') + assert c.get_context_size() == 1 + assert str(c) == '"s" = %(0)s' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {'0': set((1, 2))}) + assert ctx == {'0': set((1, 2))} def test_null_update(self): """ tests setting a set to None creates an empty update statement """ @@ -51,16 +51,16 @@ def test_null_update(self): c._analyze() c.set_context_id(0) - self.assertIsNone(c._assignments) - self.assertIsNone(c._additions) - self.assertIsNone(c._removals) + assert c._assignments is None + assert c._additions is None + assert c._removals is None - self.assertEqual(c.get_context_size(), 0) - self.assertEqual(str(c), '') + assert c.get_context_size() == 0 + assert str(c) == '' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {}) + assert ctx == {} def test_no_update(self): """ tests an unchanged value creates an empty update statement """ @@ -68,16 +68,16 @@ def test_no_update(self): c._analyze() c.set_context_id(0) - self.assertIsNone(c._assignments) - self.assertIsNone(c._additions) - self.assertIsNone(c._removals) + assert c._assignments is None + assert c._additions is None + assert c._removals is None - self.assertEqual(c.get_context_size(), 0) - self.assertEqual(str(c), '') + assert c.get_context_size() == 0 + assert str(c) == '' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {}) + assert ctx == {} def test_update_empty_set(self): """tests assigning a set to an empty set creates a nonempty @@ -86,64 +86,64 @@ def 
test_update_empty_set(self): c._analyze() c.set_context_id(0) - self.assertEqual(c._assignments, set()) - self.assertIsNone(c._additions) - self.assertIsNone(c._removals) + assert c._assignments == set() + assert c._additions is None + assert c._removals is None - self.assertEqual(c.get_context_size(), 1) - self.assertEqual(str(c), '"s" = %(0)s') + assert c.get_context_size() == 1 + assert str(c) == '"s" = %(0)s' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {'0': set()}) + assert ctx == {'0': set()} def test_additions(self): c = SetUpdateClause('s', set((1, 2, 3)), previous=set((1, 2))) c._analyze() c.set_context_id(0) - self.assertIsNone(c._assignments) - self.assertEqual(c._additions, set((3,))) - self.assertIsNone(c._removals) + assert c._assignments is None + assert c._additions == set((3,)) + assert c._removals is None - self.assertEqual(c.get_context_size(), 1) - self.assertEqual(str(c), '"s" = "s" + %(0)s') + assert c.get_context_size() == 1 + assert str(c) == '"s" = "s" + %(0)s' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {'0': set((3,))}) + assert ctx == {'0': set((3,))} def test_removals(self): c = SetUpdateClause('s', set((1, 2)), previous=set((1, 2, 3))) c._analyze() c.set_context_id(0) - self.assertIsNone(c._assignments) - self.assertIsNone(c._additions) - self.assertEqual(c._removals, set((3,))) + assert c._assignments is None + assert c._additions is None + assert c._removals == set((3,)) - self.assertEqual(c.get_context_size(), 1) - self.assertEqual(str(c), '"s" = "s" - %(0)s') + assert c.get_context_size() == 1 + assert str(c) == '"s" = "s" - %(0)s' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {'0': set((3,))}) + assert ctx == {'0': set((3,))} def test_additions_and_removals(self): c = SetUpdateClause('s', set((2, 3)), previous=set((1, 2))) c._analyze() c.set_context_id(0) - self.assertIsNone(c._assignments) - self.assertEqual(c._additions, set((3,))) - self.assertEqual(c._removals, set((1,))) + assert c._assignments 
is None + assert c._additions == set((3,)) + assert c._removals == set((1,)) - self.assertEqual(c.get_context_size(), 2) - self.assertEqual(str(c), '"s" = "s" + %(0)s, "s" = "s" - %(1)s') + assert c.get_context_size() == 2 + assert str(c) == '"s" = "s" + %(0)s, "s" = "s" - %(1)s' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {'0': set((3,)), '1': set((1,))}) + assert ctx == {'0': set((3,)), '1': set((1,))} class ListUpdateClauseTests(unittest.TestCase): @@ -153,96 +153,96 @@ def test_update_from_none(self): c._analyze() c.set_context_id(0) - self.assertEqual(c._assignments, [1, 2, 3]) - self.assertIsNone(c._append) - self.assertIsNone(c._prepend) + assert c._assignments == [1, 2, 3] + assert c._append is None + assert c._prepend is None - self.assertEqual(c.get_context_size(), 1) - self.assertEqual(str(c), '"s" = %(0)s') + assert c.get_context_size() == 1 + assert str(c) == '"s" = %(0)s' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {'0': [1, 2, 3]}) + assert ctx == {'0': [1, 2, 3]} def test_update_from_empty(self): c = ListUpdateClause('s', [1, 2, 3], previous=[]) c._analyze() c.set_context_id(0) - self.assertEqual(c._assignments, [1, 2, 3]) - self.assertIsNone(c._append) - self.assertIsNone(c._prepend) + assert c._assignments == [1, 2, 3] + assert c._append is None + assert c._prepend is None - self.assertEqual(c.get_context_size(), 1) - self.assertEqual(str(c), '"s" = %(0)s') + assert c.get_context_size() == 1 + assert str(c) == '"s" = %(0)s' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {'0': [1, 2, 3]}) + assert ctx == {'0': [1, 2, 3]} def test_update_from_different_list(self): c = ListUpdateClause('s', [1, 2, 3], previous=[3, 2, 1]) c._analyze() c.set_context_id(0) - self.assertEqual(c._assignments, [1, 2, 3]) - self.assertIsNone(c._append) - self.assertIsNone(c._prepend) + assert c._assignments == [1, 2, 3] + assert c._append is None + assert c._prepend is None - self.assertEqual(c.get_context_size(), 1) - self.assertEqual(str(c), 
'"s" = %(0)s') + assert c.get_context_size() == 1 + assert str(c) == '"s" = %(0)s' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {'0': [1, 2, 3]}) + assert ctx == {'0': [1, 2, 3]} def test_append(self): c = ListUpdateClause('s', [1, 2, 3, 4], previous=[1, 2]) c._analyze() c.set_context_id(0) - self.assertIsNone(c._assignments) - self.assertEqual(c._append, [3, 4]) - self.assertIsNone(c._prepend) + assert c._assignments is None + assert c._append == [3, 4] + assert c._prepend is None - self.assertEqual(c.get_context_size(), 1) - self.assertEqual(str(c), '"s" = "s" + %(0)s') + assert c.get_context_size() == 1 + assert str(c) == '"s" = "s" + %(0)s' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {'0': [3, 4]}) + assert ctx == {'0': [3, 4]} def test_prepend(self): c = ListUpdateClause('s', [1, 2, 3, 4], previous=[3, 4]) c._analyze() c.set_context_id(0) - self.assertIsNone(c._assignments) - self.assertIsNone(c._append) - self.assertEqual(c._prepend, [1, 2]) + assert c._assignments is None + assert c._append is None + assert c._prepend == [1, 2] - self.assertEqual(c.get_context_size(), 1) - self.assertEqual(str(c), '"s" = %(0)s + "s"') + assert c.get_context_size() == 1 + assert str(c) == '"s" = %(0)s + "s"' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {'0': [1, 2]}) + assert ctx == {'0': [1, 2]} def test_append_and_prepend(self): c = ListUpdateClause('s', [1, 2, 3, 4, 5, 6], previous=[3, 4]) c._analyze() c.set_context_id(0) - self.assertIsNone(c._assignments) - self.assertEqual(c._append, [5, 6]) - self.assertEqual(c._prepend, [1, 2]) + assert c._assignments is None + assert c._append == [5, 6] + assert c._prepend == [1, 2] - self.assertEqual(c.get_context_size(), 2) - self.assertEqual(str(c), '"s" = %(0)s + "s", "s" = "s" + %(1)s') + assert c.get_context_size() == 2 + assert str(c) == '"s" = %(0)s + "s", "s" = "s" + %(1)s' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {'0': [1, 2], '1': [5, 6]}) + assert ctx == {'0': [1, 2], '1': [5, 
6]} def test_shrinking_list_update(self): """ tests that updating to a smaller list results in an insert statement """ @@ -250,16 +250,16 @@ def test_shrinking_list_update(self): c._analyze() c.set_context_id(0) - self.assertEqual(c._assignments, [1, 2, 3]) - self.assertIsNone(c._append) - self.assertIsNone(c._prepend) + assert c._assignments == [1, 2, 3] + assert c._append is None + assert c._prepend is None - self.assertEqual(c.get_context_size(), 1) - self.assertEqual(str(c), '"s" = %(0)s') + assert c.get_context_size() == 1 + assert str(c) == '"s" = %(0)s' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {'0': [1, 2, 3]}) + assert ctx == {'0': [1, 2, 3]} class MapUpdateTests(unittest.TestCase): @@ -269,33 +269,33 @@ def test_update(self): c._analyze() c.set_context_id(0) - self.assertEqual(c._updates, [3, 5]) - self.assertEqual(c.get_context_size(), 4) - self.assertEqual(str(c), '"s"[%(0)s] = %(1)s, "s"[%(2)s] = %(3)s') + assert c._updates == [3, 5] + assert c.get_context_size() == 4 + assert str(c) == '"s"[%(0)s] = %(1)s, "s"[%(2)s] = %(3)s' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {'0': 3, "1": 0, '2': 5, '3': 6}) + assert ctx == {'0': 3, "1": 0, '2': 5, '3': 6} def test_update_from_null(self): c = MapUpdateClause('s', {3: 0, 5: 6}) c._analyze() c.set_context_id(0) - self.assertEqual(c._updates, [3, 5]) - self.assertEqual(c.get_context_size(), 4) - self.assertEqual(str(c), '"s"[%(0)s] = %(1)s, "s"[%(2)s] = %(3)s') + assert c._updates == [3, 5] + assert c.get_context_size() == 4 + assert str(c) == '"s"[%(0)s] = %(1)s, "s"[%(2)s] = %(3)s' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {'0': 3, "1": 0, '2': 5, '3': 6}) + assert ctx == {'0': 3, "1": 0, '2': 5, '3': 6} def test_nulled_columns_arent_included(self): c = MapUpdateClause('s', {3: 0}, {1: 2, 3: 4}) c._analyze() c.set_context_id(0) - self.assertNotIn(1, c._updates) + assert 1 not in c._updates class CounterUpdateTests(unittest.TestCase): @@ -304,34 +304,34 @@ def 
test_positive_update(self): c = CounterUpdateClause('a', 5, 3) c.set_context_id(5) - self.assertEqual(c.get_context_size(), 1) - self.assertEqual(str(c), '"a" = "a" + %(5)s') + assert c.get_context_size() == 1 + assert str(c) == '"a" = "a" + %(5)s' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {'5': 2}) + assert ctx == {'5': 2} def test_negative_update(self): c = CounterUpdateClause('a', 4, 7) c.set_context_id(3) - self.assertEqual(c.get_context_size(), 1) - self.assertEqual(str(c), '"a" = "a" - %(3)s') + assert c.get_context_size() == 1 + assert str(c) == '"a" = "a" - %(3)s' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {'3': 3}) + assert ctx == {'3': 3} def noop_update(self): c = CounterUpdateClause('a', 5, 5) c.set_context_id(5) - self.assertEqual(c.get_context_size(), 1) - self.assertEqual(str(c), '"a" = "a" + %(0)s') + assert c.get_context_size() == 1 + assert str(c) == '"a" = "a" + %(0)s' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {'5': 0}) + assert ctx == {'5': 0} class MapDeleteTests(unittest.TestCase): @@ -341,13 +341,13 @@ def test_update(self): c._analyze() c.set_context_id(0) - self.assertEqual(c._removals, [1, 5]) - self.assertEqual(c.get_context_size(), 2) - self.assertEqual(str(c), '"s"[%(0)s], "s"[%(1)s]') + assert c._removals == [1, 5] + assert c.get_context_size() == 2 + assert str(c) == '"s"[%(0)s], "s"[%(1)s]' ctx = {} c.update_context(ctx) - self.assertEqual(ctx, {'0': 1, '1': 5}) + assert ctx == {'0': 1, '1': 5} class FieldDeleteTests(unittest.TestCase): diff --git a/tests/integration/cqlengine/statements/test_base_statement.py b/tests/integration/cqlengine/statements/test_base_statement.py index bbcdc700e1..a2af184235 100644 --- a/tests/integration/cqlengine/statements/test_base_statement.py +++ b/tests/integration/cqlengine/statements/test_base_statement.py @@ -35,13 +35,13 @@ class BaseStatementTest(unittest.TestCase): def test_fetch_size(self): """ tests that fetch_size is correctly set """ stmt = 
BaseCQLStatement('table', None, fetch_size=1000) - self.assertEqual(stmt.fetch_size, 1000) + assert stmt.fetch_size == 1000 stmt = BaseCQLStatement('table', None, fetch_size=None) - self.assertEqual(stmt.fetch_size, FETCH_SIZE_UNSET) + assert stmt.fetch_size == FETCH_SIZE_UNSET stmt = BaseCQLStatement('table', None) - self.assertEqual(stmt.fetch_size, FETCH_SIZE_UNSET) + assert stmt.fetch_size == FETCH_SIZE_UNSET class ExecuteStatementTest(BaseCassEngTestCase): @@ -64,8 +64,8 @@ def _verify_statement(self, original): response = result.one() for assignment in original.assignments: - self.assertEqual(response[assignment.field], assignment.value) - self.assertEqual(len(response), 8) + assert response[assignment.field] == assignment.value + assert len(response) == 8 def test_insert_statement_execute(self): """ @@ -99,7 +99,7 @@ def test_insert_statement_execute(self): # Verifying delete statement execute(DeleteStatement(self.table_name, where=where)) - self.assertEqual(TestQueryUpdateModel.objects.count(), 0) + assert TestQueryUpdateModel.objects.count() == 0 @greaterthanorequalcass3_10 @requires_custom_indexes @@ -128,17 +128,16 @@ def test_like_operator(self): ss = SelectStatement(self.table_name) like_clause = "text_for_%" ss.add_where(Column(db_field='text'), LikeOperator(), like_clause) - self.assertEqual(str(ss), - 'SELECT * FROM {} WHERE "text" LIKE %(0)s'.format(self.table_name)) + assert str(ss) == 'SELECT * FROM {} WHERE "text" LIKE %(0)s'.format(self.table_name) result = execute(ss) - self.assertEqual(result[0]["text"], self.text) + assert result[0]["text"] == self.text q = TestQueryUpdateModel.objects.filter(text__like=like_clause).allow_filtering() - self.assertEqual(q[0].text, self.text) + assert q[0].text == self.text q = TestQueryUpdateModel.objects.filter(text__like=like_clause) - self.assertEqual(q[0].text, self.text) + assert q[0].text == self.text def _insert_statement(self, partition, cluster): # Verifying insert statement diff --git 
a/tests/integration/cqlengine/statements/test_delete_statement.py b/tests/integration/cqlengine/statements/test_delete_statement.py index 745881f42f..c5335e94c4 100644 --- a/tests/integration/cqlengine/statements/test_delete_statement.py +++ b/tests/integration/cqlengine/statements/test_delete_statement.py @@ -24,30 +24,30 @@ class DeleteStatementTests(TestCase): def test_single_field_is_listified(self): """ tests that passing a string field into the constructor puts it into a list """ ds = DeleteStatement('table', 'field') - self.assertEqual(len(ds.fields), 1) - self.assertEqual(ds.fields[0].field, 'field') + assert len(ds.fields) == 1 + assert ds.fields[0].field == 'field' def test_field_rendering(self): """ tests that fields are properly added to the select statement """ ds = DeleteStatement('table', ['f1', 'f2']) - self.assertTrue(str(ds).startswith('DELETE "f1", "f2"'), str(ds)) - self.assertTrue(str(ds).startswith('DELETE "f1", "f2"'), str(ds)) + assert str(ds).startswith('DELETE "f1", "f2"'), str(ds) + assert str(ds).startswith('DELETE "f1", "f2"'), str(ds) def test_none_fields_rendering(self): """ tests that a '*' is added if no fields are passed in """ ds = DeleteStatement('table', None) - self.assertTrue(str(ds).startswith('DELETE FROM'), str(ds)) - self.assertTrue(str(ds).startswith('DELETE FROM'), str(ds)) + assert str(ds).startswith('DELETE FROM'), str(ds) + assert str(ds).startswith('DELETE FROM'), str(ds) def test_table_rendering(self): ds = DeleteStatement('table', None) - self.assertTrue(str(ds).startswith('DELETE FROM table'), str(ds)) - self.assertTrue(str(ds).startswith('DELETE FROM table'), str(ds)) + assert str(ds).startswith('DELETE FROM table'), str(ds) + assert str(ds).startswith('DELETE FROM table'), str(ds) def test_where_clause_rendering(self): ds = DeleteStatement('table', None) ds.add_where(Column(db_field='a'), EqualsOperator(), 'b') - self.assertEqual(str(ds), 'DELETE FROM table WHERE "a" = %(0)s', str(ds)) + assert str(ds) == 
'DELETE FROM table WHERE "a" = %(0)s', str(ds) def test_context_update(self): ds = DeleteStatement('table', None) @@ -55,36 +55,36 @@ def test_context_update(self): ds.add_where(Column(db_field='a'), EqualsOperator(), 'b') ds.update_context_id(7) - self.assertEqual(str(ds), 'DELETE "d"[%(8)s] FROM table WHERE "a" = %(7)s') - self.assertEqual(ds.get_context(), {'7': 'b', '8': 3}) + assert str(ds) == 'DELETE "d"[%(8)s] FROM table WHERE "a" = %(7)s' + assert ds.get_context() == {'7': 'b', '8': 3} def test_context(self): ds = DeleteStatement('table', None) ds.add_where(Column(db_field='a'), EqualsOperator(), 'b') - self.assertEqual(ds.get_context(), {'0': 'b'}) + assert ds.get_context() == {'0': 'b'} def test_range_deletion_rendering(self): ds = DeleteStatement('table', None) ds.add_where(Column(db_field='a'), EqualsOperator(), 'b') ds.add_where(Column(db_field='created_at'), GreaterThanOrEqualOperator(), '0') ds.add_where(Column(db_field='created_at'), LessThanOrEqualOperator(), '10') - self.assertEqual(str(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" >= %(1)s AND "created_at" <= %(2)s', str(ds)) + assert str(ds) == 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" >= %(1)s AND "created_at" <= %(2)s', str(ds) ds = DeleteStatement('table', None) ds.add_where(Column(db_field='a'), EqualsOperator(), 'b') ds.add_where(Column(db_field='created_at'), InOperator(), ['0', '10', '20']) - self.assertEqual(str(ds), 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" IN %(1)s', str(ds)) + assert str(ds) == 'DELETE FROM table WHERE "a" = %(0)s AND "created_at" IN %(1)s', str(ds) ds = DeleteStatement('table', None) ds.add_where(Column(db_field='a'), NotEqualsOperator(), 'b') - self.assertEqual(str(ds), 'DELETE FROM table WHERE "a" != %(0)s', str(ds)) + assert str(ds) == 'DELETE FROM table WHERE "a" != %(0)s', str(ds) def test_delete_conditional(self): where = [WhereClause('id', EqualsOperator(), 1)] conditionals = [ConditionalClause('f0', 'value0'), 
ConditionalClause('f1', 'value1')] ds = DeleteStatement('table', where=where, conditionals=conditionals) - self.assertEqual(len(ds.conditionals), len(conditionals)) - self.assertEqual(str(ds), 'DELETE FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', str(ds)) + assert len(ds.conditionals) == len(conditionals) + assert str(ds) == 'DELETE FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', str(ds) fields = ['one', 'two'] ds = DeleteStatement('table', fields=fields, where=where, conditionals=conditionals) - self.assertEqual(str(ds), 'DELETE "one", "two" FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', str(ds)) + assert str(ds) == 'DELETE "one", "two" FROM table WHERE "id" = %(0)s IF "f0" = %(1)s AND "f1" = %(2)s', str(ds) diff --git a/tests/integration/cqlengine/statements/test_insert_statement.py b/tests/integration/cqlengine/statements/test_insert_statement.py index 45485af912..0f70253a6c 100644 --- a/tests/integration/cqlengine/statements/test_insert_statement.py +++ b/tests/integration/cqlengine/statements/test_insert_statement.py @@ -24,10 +24,7 @@ def test_statement(self): ist.add_assignment(Column(db_field='a'), 'b') ist.add_assignment(Column(db_field='c'), 'd') - self.assertEqual( - str(ist), - 'INSERT INTO table ("a", "c") VALUES (%(0)s, %(1)s)' - ) + assert str(ist) == 'INSERT INTO table ("a", "c") VALUES (%(0)s, %(1)s)' def test_context_update(self): ist = InsertStatement('table', None) @@ -35,15 +32,12 @@ def test_context_update(self): ist.add_assignment(Column(db_field='c'), 'd') ist.update_context_id(4) - self.assertEqual( - str(ist), - 'INSERT INTO table ("a", "c") VALUES (%(4)s, %(5)s)' - ) + assert str(ist) == 'INSERT INTO table ("a", "c") VALUES (%(4)s, %(5)s)' ctx = ist.get_context() - self.assertEqual(ctx, {'4': 'b', '5': 'd'}) + assert ctx == {'4': 'b', '5': 'd'} def test_additional_rendering(self): ist = InsertStatement('table', ttl=60) ist.add_assignment(Column(db_field='a'), 'b') 
ist.add_assignment(Column(db_field='c'), 'd') - self.assertIn('USING TTL 60', str(ist)) + assert 'USING TTL 60' in str(ist) diff --git a/tests/integration/cqlengine/statements/test_select_statement.py b/tests/integration/cqlengine/statements/test_select_statement.py index 26c9c804cb..b4bada1eb0 100644 --- a/tests/integration/cqlengine/statements/test_select_statement.py +++ b/tests/integration/cqlengine/statements/test_select_statement.py @@ -22,63 +22,63 @@ class SelectStatementTests(unittest.TestCase): def test_single_field_is_listified(self): """ tests that passing a string field into the constructor puts it into a list """ ss = SelectStatement('table', 'field') - self.assertEqual(ss.fields, ['field']) + assert ss.fields == ['field'] def test_field_rendering(self): """ tests that fields are properly added to the select statement """ ss = SelectStatement('table', ['f1', 'f2']) - self.assertTrue(str(ss).startswith('SELECT "f1", "f2"'), str(ss)) - self.assertTrue(str(ss).startswith('SELECT "f1", "f2"'), str(ss)) + assert str(ss).startswith('SELECT "f1", "f2"'), str(ss) + assert str(ss).startswith('SELECT "f1", "f2"'), str(ss) def test_none_fields_rendering(self): """ tests that a '*' is added if no fields are passed in """ ss = SelectStatement('table') - self.assertTrue(str(ss).startswith('SELECT *'), str(ss)) - self.assertTrue(str(ss).startswith('SELECT *'), str(ss)) + assert str(ss).startswith('SELECT *'), str(ss) + assert str(ss).startswith('SELECT *'), str(ss) def test_table_rendering(self): ss = SelectStatement('table') - self.assertTrue(str(ss).startswith('SELECT * FROM table'), str(ss)) - self.assertTrue(str(ss).startswith('SELECT * FROM table'), str(ss)) + assert str(ss).startswith('SELECT * FROM table'), str(ss) + assert str(ss).startswith('SELECT * FROM table'), str(ss) def test_where_clause_rendering(self): ss = SelectStatement('table') ss.add_where(Column(db_field='a'), EqualsOperator(), 'b') - self.assertEqual(str(ss), 'SELECT * FROM table WHERE "a" = 
%(0)s', str(ss)) + assert str(ss) == 'SELECT * FROM table WHERE "a" = %(0)s', str(ss) def test_count(self): ss = SelectStatement('table', count=True, limit=10, order_by='d') ss.add_where(Column(db_field='a'), EqualsOperator(), 'b') - self.assertEqual(str(ss), 'SELECT COUNT(*) FROM table WHERE "a" = %(0)s LIMIT 10', str(ss)) - self.assertIn('LIMIT', str(ss)) - self.assertNotIn('ORDER', str(ss)) + assert str(ss) == 'SELECT COUNT(*) FROM table WHERE "a" = %(0)s LIMIT 10', str(ss) + assert 'LIMIT' in str(ss) + assert 'ORDER' not in str(ss) def test_distinct(self): ss = SelectStatement('table', distinct_fields=['field2']) ss.add_where(Column(db_field='field1'), EqualsOperator(), 'b') - self.assertEqual(str(ss), 'SELECT DISTINCT "field2" FROM table WHERE "field1" = %(0)s', str(ss)) + assert str(ss) == 'SELECT DISTINCT "field2" FROM table WHERE "field1" = %(0)s', str(ss) ss = SelectStatement('table', distinct_fields=['field1', 'field2']) - self.assertEqual(str(ss), 'SELECT DISTINCT "field1", "field2" FROM table') + assert str(ss) == 'SELECT DISTINCT "field1", "field2" FROM table' ss = SelectStatement('table', distinct_fields=['field1'], count=True) - self.assertEqual(str(ss), 'SELECT DISTINCT COUNT("field1") FROM table') + assert str(ss) == 'SELECT DISTINCT COUNT("field1") FROM table' def test_context(self): ss = SelectStatement('table') ss.add_where(Column(db_field='a'), EqualsOperator(), 'b') - self.assertEqual(ss.get_context(), {'0': 'b'}) + assert ss.get_context() == {'0': 'b'} def test_context_id_update(self): """ tests that the right things happen the the context id """ ss = SelectStatement('table') ss.add_where(Column(db_field='a'), EqualsOperator(), 'b') - self.assertEqual(ss.get_context(), {'0': 'b'}) - self.assertEqual(str(ss), 'SELECT * FROM table WHERE "a" = %(0)s') + assert ss.get_context() == {'0': 'b'} + assert str(ss) == 'SELECT * FROM table WHERE "a" = %(0)s' ss.update_context_id(5) - self.assertEqual(ss.get_context(), {'5': 'b'}) - 
self.assertEqual(str(ss), 'SELECT * FROM table WHERE "a" = %(5)s') + assert ss.get_context() == {'5': 'b'} + assert str(ss) == 'SELECT * FROM table WHERE "a" = %(5)s' def test_additional_rendering(self): ss = SelectStatement( @@ -89,19 +89,19 @@ def test_additional_rendering(self): allow_filtering=True ) qstr = str(ss) - self.assertIn('LIMIT 15', qstr) - self.assertIn('ORDER BY x, y', qstr) - self.assertIn('ALLOW FILTERING', qstr) + assert 'LIMIT 15' in qstr + assert 'ORDER BY x, y' in qstr + assert 'ALLOW FILTERING' in qstr def test_limit_rendering(self): ss = SelectStatement('table', None, limit=10) qstr = str(ss) - self.assertIn('LIMIT 10', qstr) + assert 'LIMIT 10' in qstr ss = SelectStatement('table', None, limit=0) qstr = str(ss) - self.assertNotIn('LIMIT', qstr) + assert 'LIMIT' not in qstr ss = SelectStatement('table', None, limit=None) qstr = str(ss) - self.assertNotIn('LIMIT', qstr) + assert 'LIMIT' not in qstr diff --git a/tests/integration/cqlengine/statements/test_update_statement.py b/tests/integration/cqlengine/statements/test_update_statement.py index 4429625bf4..6529b73558 100644 --- a/tests/integration/cqlengine/statements/test_update_statement.py +++ b/tests/integration/cqlengine/statements/test_update_statement.py @@ -25,25 +25,25 @@ class UpdateStatementTests(unittest.TestCase): def test_table_rendering(self): """ tests that fields are properly added to the select statement """ us = UpdateStatement('table') - self.assertTrue(str(us).startswith('UPDATE table SET'), str(us)) - self.assertTrue(str(us).startswith('UPDATE table SET'), str(us)) + assert str(us).startswith('UPDATE table SET'), str(us) + assert str(us).startswith('UPDATE table SET'), str(us) def test_rendering(self): us = UpdateStatement('table') us.add_assignment(Column(db_field='a'), 'b') us.add_assignment(Column(db_field='c'), 'd') us.add_where(Column(db_field='a'), EqualsOperator(), 'x') - self.assertEqual(str(us), 'UPDATE table SET "a" = %(0)s, "c" = %(1)s WHERE "a" = %(2)s', 
str(us)) + assert str(us) == 'UPDATE table SET "a" = %(0)s, "c" = %(1)s WHERE "a" = %(2)s', str(us) us.add_where(Column(db_field='a'), NotEqualsOperator(), 'y') - self.assertEqual(str(us), 'UPDATE table SET "a" = %(0)s, "c" = %(1)s WHERE "a" = %(2)s AND "a" != %(3)s', str(us)) + assert str(us) == 'UPDATE table SET "a" = %(0)s, "c" = %(1)s WHERE "a" = %(2)s AND "a" != %(3)s', str(us) def test_context(self): us = UpdateStatement('table') us.add_assignment(Column(db_field='a'), 'b') us.add_assignment(Column(db_field='c'), 'd') us.add_where(Column(db_field='a'), EqualsOperator(), 'x') - self.assertEqual(us.get_context(), {'0': 'b', '1': 'd', '2': 'x'}) + assert us.get_context() == {'0': 'b', '1': 'd', '2': 'x'} def test_context_update(self): us = UpdateStatement('table') @@ -51,36 +51,36 @@ def test_context_update(self): us.add_assignment(Column(db_field='c'), 'd') us.add_where(Column(db_field='a'), EqualsOperator(), 'x') us.update_context_id(3) - self.assertEqual(str(us), 'UPDATE table SET "a" = %(4)s, "c" = %(5)s WHERE "a" = %(3)s') - self.assertEqual(us.get_context(), {'4': 'b', '5': 'd', '3': 'x'}) + assert str(us) == 'UPDATE table SET "a" = %(4)s, "c" = %(5)s WHERE "a" = %(3)s' + assert us.get_context() == {'4': 'b', '5': 'd', '3': 'x'} def test_additional_rendering(self): us = UpdateStatement('table', ttl=60) us.add_assignment(Column(db_field='a'), 'b') us.add_where(Column(db_field='a'), EqualsOperator(), 'x') - self.assertIn('USING TTL 60', str(us)) + assert 'USING TTL 60' in str(us) def test_update_set_add(self): us = UpdateStatement('table') us.add_update(Set(Text, db_field='a'), set((1,)), 'add') - self.assertEqual(str(us), 'UPDATE table SET "a" = "a" + %(0)s') + assert str(us) == 'UPDATE table SET "a" = "a" + %(0)s' def test_update_empty_set_add_does_not_assign(self): us = UpdateStatement('table') us.add_update(Set(Text, db_field='a'), set(), 'add') - self.assertFalse(us.assignments) + assert not us.assignments def 
test_update_empty_set_removal_does_not_assign(self): us = UpdateStatement('table') us.add_update(Set(Text, db_field='a'), set(), 'remove') - self.assertFalse(us.assignments) + assert not us.assignments def test_update_list_prepend_with_empty_list(self): us = UpdateStatement('table') us.add_update(List(Text, db_field='a'), [], 'prepend') - self.assertFalse(us.assignments) + assert not us.assignments def test_update_list_append_with_empty_list(self): us = UpdateStatement('table') us.add_update(List(Text, db_field='a'), [], 'append') - self.assertFalse(us.assignments) + assert not us.assignments diff --git a/tests/integration/cqlengine/statements/test_where_clause.py b/tests/integration/cqlengine/statements/test_where_clause.py index 0090fa0123..8ac2536a19 100644 --- a/tests/integration/cqlengine/statements/test_where_clause.py +++ b/tests/integration/cqlengine/statements/test_where_clause.py @@ -15,13 +15,14 @@ from cassandra.cqlengine.operators import EqualsOperator from cassandra.cqlengine.statements import StatementException, WhereClause +import pytest class TestWhereClause(unittest.TestCase): def test_operator_check(self): """ tests that creating a where statement with a non BaseWhereOperator object fails """ - with self.assertRaises(StatementException): + with pytest.raises(StatementException): WhereClause('a', 'b', 'c') def test_where_clause_rendering(self): @@ -29,8 +30,8 @@ def test_where_clause_rendering(self): wc = WhereClause('a', EqualsOperator(), 'c') wc.set_context_id(5) - self.assertEqual('"a" = %(5)s', str(wc), str(wc)) - self.assertEqual('"a" = %(5)s', str(wc), type(wc)) + assert '"a" = %(5)s' == str(wc), str(wc) + assert '"a" = %(5)s' == str(wc), type(wc) def test_equality_method(self): """ tests that 2 identical where clauses evaluate as == """ diff --git a/tests/integration/cqlengine/test_batch_query.py b/tests/integration/cqlengine/test_batch_query.py index 399bee6202..cd6bf0fe93 100644 --- a/tests/integration/cqlengine/test_batch_query.py +++ 
b/tests/integration/cqlengine/test_batch_query.py @@ -20,6 +20,7 @@ from cassandra.cqlengine.models import Model from cassandra.cqlengine.query import BatchQuery from tests.integration.cqlengine.base import BaseCassEngTestCase +from tests.util import assertRegex from unittest.mock import patch @@ -56,7 +57,7 @@ def test_insert_success_case(self): b = BatchQuery() TestMultiKeyModel.batch(b).create(partition=self.pkey, cluster=2, count=3, text='4') - with self.assertRaises(TestMultiKeyModel.DoesNotExist): + with pytest.raises(TestMultiKeyModel.DoesNotExist): TestMultiKeyModel.get(partition=self.pkey, cluster=2) b.execute() @@ -73,12 +74,12 @@ def test_update_success_case(self): inst.batch(b).save() inst2 = TestMultiKeyModel.get(partition=self.pkey, cluster=2) - self.assertEqual(inst2.count, 3) + assert inst2.count == 3 b.execute() inst3 = TestMultiKeyModel.get(partition=self.pkey, cluster=2) - self.assertEqual(inst3.count, 4) + assert inst3.count == 4 def test_delete_success_case(self): @@ -92,7 +93,7 @@ def test_delete_success_case(self): b.execute() - with self.assertRaises(TestMultiKeyModel.DoesNotExist): + with pytest.raises(TestMultiKeyModel.DoesNotExist): TestMultiKeyModel.get(partition=self.pkey, cluster=2) def test_context_manager(self): @@ -102,7 +103,7 @@ def test_context_manager(self): TestMultiKeyModel.batch(b).create(partition=self.pkey, cluster=i, count=3, text='4') for i in range(5): - with self.assertRaises(TestMultiKeyModel.DoesNotExist): + with pytest.raises(TestMultiKeyModel.DoesNotExist): TestMultiKeyModel.get(partition=self.pkey, cluster=i) for i in range(5): @@ -116,9 +117,9 @@ def test_bulk_delete_success_case(self): with BatchQuery() as b: TestMultiKeyModel.objects.batch(b).filter(partition=0).delete() - self.assertEqual(TestMultiKeyModel.filter(partition=0).count(), 5) + assert TestMultiKeyModel.filter(partition=0).count() == 5 - self.assertEqual(TestMultiKeyModel.filter(partition=0).count(), 0) + assert 
TestMultiKeyModel.filter(partition=0).count() == 0 #cleanup for m in TestMultiKeyModel.all(): m.delete() @@ -146,11 +147,11 @@ def my_callback(*args, **kwargs): batch.add_callback(my_callback, 2, named_arg='value') batch.add_callback(my_callback, 1, 3) - self.assertEqual(batch._callbacks, [ + assert batch._callbacks == [ (my_callback, (), {}), (my_callback, (2,), {'named_arg':'value'}), (my_callback, (1, 3), {}) - ]) + ] def test_callbacks_properly_execute_callables_and_tuples(self): @@ -166,8 +167,8 @@ def my_callback(*args, **kwargs): batch.execute() - self.assertEqual(len(call_history), 2) - self.assertEqual([(), ('more', 'args')], call_history) + assert len(call_history) == 2 + assert [(), ('more', 'args')] == call_history def test_callbacks_tied_to_execute(self): """Batch callbacks should NOT fire if batch is not executed in context manager mode""" @@ -179,12 +180,12 @@ def my_callback(*args, **kwargs): with BatchQuery() as batch: batch.add_callback(my_callback) - self.assertEqual(len(call_history), 1) + assert len(call_history) == 1 class SomeError(Exception): pass - with self.assertRaises(SomeError): + with pytest.raises(SomeError): with BatchQuery() as batch: batch.add_callback(my_callback) # this error bubbling up through context manager @@ -192,17 +193,17 @@ class SomeError(Exception): raise SomeError # still same call history. 
Nothing added - self.assertEqual(len(call_history), 1) + assert len(call_history) == 1 # but if execute ran, even with an error bubbling through # the callbacks also would have fired - with self.assertRaises(SomeError): + with pytest.raises(SomeError): with BatchQuery(execute_on_exception=True) as batch: batch.add_callback(my_callback) raise SomeError # updated call history - self.assertEqual(len(call_history), 2) + assert len(call_history) == 2 def test_callbacks_work_multiple_times(self): """ @@ -224,8 +225,8 @@ def my_callback(*args, **kwargs): batch.add_callback(my_callback) batch.execute() batch.execute() - self.assertEqual(len(w), 2) # package filter setup to warn always - self.assertRegex(str(w[0].message), r"^Batch.*multiple.*") + assert len(w) == 2 # package filter setup to warn always + assertRegex(str(w[0].message), r"^Batch.*multiple.*") def test_disable_multiple_callback_warning(self): """ @@ -250,4 +251,4 @@ def my_callback(*args, **kwargs): batch.add_callback(my_callback) batch.execute() batch.execute() - self.assertFalse(w) + assert not w diff --git a/tests/integration/cqlengine/test_connections.py b/tests/integration/cqlengine/test_connections.py index 32db143088..612255bdc5 100644 --- a/tests/integration/cqlengine/test_connections.py +++ b/tests/integration/cqlengine/test_connections.py @@ -23,6 +23,7 @@ from tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration.cqlengine.query import test_queryset from tests.integration import local, CASSANDRA_IP, TestCluster +import pytest class TestModel(Model): @@ -98,7 +99,7 @@ def test_context_connection_priority(self): # ContextQuery connection should have priority over default one with ContextQuery(TestModel, connection='fake_cluster') as tm: - with self.assertRaises(NoHostAvailable): + with pytest.raises(NoHostAvailable): tm.objects.create(partition=1, cluster=1) # Explicit connection should have priority over ContextQuery one @@ -110,7 +111,7 @@ def 
test_context_connection_priority(self): # No model connection and an invalid default connection with ContextQuery(TestModel) as tm: - with self.assertRaises(NoHostAvailable): + with pytest.raises(NoHostAvailable): tm.objects.create(partition=1, cluster=1) def test_context_connection_with_keyspace(self): @@ -126,7 +127,7 @@ def test_context_connection_with_keyspace(self): # ks2 doesn't exist with ContextQuery(TestModel, connection='cluster', keyspace='ks2') as tm: - with self.assertRaises(InvalidRequest): + with pytest.raises(InvalidRequest): tm.objects.create(partition=1, cluster=1) @@ -166,7 +167,7 @@ def test_create_drop_keyspace(self): """ # No connection (default is fake) - with self.assertRaises(NoHostAvailable): + with pytest.raises(NoHostAvailable): create_keyspace_simple(self.keyspaces[0], 1) # Explicit connections @@ -190,7 +191,7 @@ def test_create_drop_table(self): create_keyspace_simple(ks, 1, connections=self.conns) # No connection (default is fake) - with self.assertRaises(NoHostAvailable): + with pytest.raises(NoHostAvailable): sync_table(TestModel) # Explicit connections @@ -205,7 +206,7 @@ def test_create_drop_table(self): TestModel.__connection__ = None # No connection (default is fake) - with self.assertRaises(NoHostAvailable): + with pytest.raises(NoHostAvailable): drop_table(TestModel) # Model connection @@ -230,7 +231,7 @@ def test_connection_creation_from_session(self): session = cluster.connect() connection_name = 'from_session' conn.register_connection(connection_name, session=session) - self.assertIsNotNone(conn.get_connection(connection_name).cluster.metadata.get_host(CASSANDRA_IP)) + assert conn.get_connection(connection_name).cluster.metadata.get_host(CASSANDRA_IP) is not None self.addCleanup(conn.unregister_connection, connection_name) cluster.shutdown() @@ -245,7 +246,7 @@ def test_connection_from_hosts(self): """ connection_name = 'from_hosts' conn.register_connection(connection_name, hosts=[CASSANDRA_IP]) - 
self.assertIsNotNone(conn.get_connection(connection_name).cluster.metadata.get_host(CASSANDRA_IP)) + assert conn.get_connection(connection_name).cluster.metadata.get_host(CASSANDRA_IP) is not None self.addCleanup(conn.unregister_connection, connection_name) def test_connection_param_validation(self): @@ -259,15 +260,15 @@ def test_connection_param_validation(self): """ cluster = TestCluster() session = cluster.connect() - with self.assertRaises(CQLEngineException): + with pytest.raises(CQLEngineException): conn.register_connection("bad_coonection1", session=session, consistency="not_null") - with self.assertRaises(CQLEngineException): + with pytest.raises(CQLEngineException): conn.register_connection("bad_coonection2", session=session, lazy_connect="not_null") - with self.assertRaises(CQLEngineException): + with pytest.raises(CQLEngineException): conn.register_connection("bad_coonection3", session=session, retry_connect="not_null") - with self.assertRaises(CQLEngineException): + with pytest.raises(CQLEngineException): conn.register_connection("bad_coonection4", session=session, cluster_options="not_null") - with self.assertRaises(CQLEngineException): + with pytest.raises(CQLEngineException): conn.register_connection("bad_coonection5", hosts="not_null", session=session) cluster.shutdown() @@ -318,7 +319,7 @@ def test_basic_batch_query(self): """ # No connection with a QuerySet (default is a fake one) - with self.assertRaises(NoHostAvailable): + with pytest.raises(NoHostAvailable): with BatchQuery() as b: TestModel.objects.batch(b).create(partition=1, cluster=1) @@ -332,7 +333,7 @@ def test_basic_batch_query(self): obj.__connection__ = None # No connection with a model (default is a fake one) - with self.assertRaises(NoHostAvailable): + with pytest.raises(NoHostAvailable): with BatchQuery() as b: obj.count = 2 obj.batch(b).save() @@ -357,7 +358,7 @@ def test_batch_query_different_connection(self): TestModel.__connection__ = 'cluster' AnotherTestModel.__connection__ = 
'cluster2' - with self.assertRaises(CQLEngineException): + with pytest.raises(CQLEngineException): with BatchQuery() as b: TestModel.objects.batch(b).create(partition=1, cluster=1) AnotherTestModel.objects.batch(b).create(partition=1, cluster=1) @@ -380,7 +381,7 @@ def test_batch_query_different_connection(self): obj1.count = 4 obj2.count = 4 - with self.assertRaises(CQLEngineException): + with pytest.raises(CQLEngineException): with BatchQuery() as b: obj1.batch(b).save() obj2.batch(b).save() @@ -396,11 +397,11 @@ def test_batch_query_connection_override(self): @test_category object_mapper """ - with self.assertRaises(CQLEngineException): + with pytest.raises(CQLEngineException): with BatchQuery(connection='cluster') as b: TestModel.batch(b).using(connection='test').save() - with self.assertRaises(CQLEngineException): + with pytest.raises(CQLEngineException): with BatchQuery(connection='cluster') as b: TestModel.using(connection='test').batch(b).save() @@ -408,11 +409,11 @@ def test_batch_query_connection_override(self): obj1 = tm.objects.get(partition=1, cluster=1) obj1.__connection__ = None - with self.assertRaises(CQLEngineException): + with pytest.raises(CQLEngineException): with BatchQuery(connection='cluster') as b: obj1.using(connection='test').batch(b).save() - with self.assertRaises(CQLEngineException): + with pytest.raises(CQLEngineException): with BatchQuery(connection='cluster') as b: obj1.batch(b).using(connection='test').save() @@ -470,26 +471,26 @@ def test_keyspace(self): tm.objects.using(keyspace='ks2').create(partition=1, cluster=1) tm.objects.using(keyspace='ks2').create(partition=2, cluster=2) - with self.assertRaises(TestModel.DoesNotExist): + with pytest.raises(TestModel.DoesNotExist): tm.objects.get(partition=1, cluster=1) # default keyspace ks1 obj1 = tm.objects.using(keyspace='ks2').get(partition=1, cluster=1) obj1.count = 2 obj1.save() - with self.assertRaises(NoHostAvailable): + with pytest.raises(NoHostAvailable): 
TestModel.objects.using(keyspace='ks2').get(partition=1, cluster=1) obj2 = TestModel.objects.using(connection='cluster', keyspace='ks2').get(partition=1, cluster=1) - self.assertEqual(obj2.count, 2) + assert obj2.count == 2 # Update test TestModel.objects(partition=2, cluster=2).using(connection='cluster', keyspace='ks2').update(count=5) obj3 = TestModel.objects.using(connection='cluster', keyspace='ks2').get(partition=2, cluster=2) - self.assertEqual(obj3.count, 5) + assert obj3.count == 5 TestModel.objects(partition=2, cluster=2).using(connection='cluster', keyspace='ks2').delete() - with self.assertRaises(TestModel.DoesNotExist): + with pytest.raises(TestModel.DoesNotExist): TestModel.objects.using(connection='cluster', keyspace='ks2').get(partition=2, cluster=2) def test_connection(self): @@ -505,20 +506,20 @@ def test_connection(self): self._reset_data() # Model class - with self.assertRaises(NoHostAvailable): + with pytest.raises(NoHostAvailable): TestModel.objects.create(partition=1, cluster=1) TestModel.objects.using(connection='cluster').create(partition=1, cluster=1) TestModel.objects(partition=1, cluster=1).using(connection='cluster').update(count=2) obj1 = TestModel.objects.using(connection='cluster').get(partition=1, cluster=1) - self.assertEqual(obj1.count, 2) + assert obj1.count == 2 obj1.using(connection='cluster').update(count=5) obj1 = TestModel.objects.using(connection='cluster').get(partition=1, cluster=1) - self.assertEqual(obj1.count, 5) + assert obj1.count == 5 obj1.using(connection='cluster').delete() - with self.assertRaises(TestModel.DoesNotExist): + with pytest.raises(TestModel.DoesNotExist): TestModel.objects.using(connection='cluster').get(partition=1, cluster=1) diff --git a/tests/integration/cqlengine/test_consistency.py b/tests/integration/cqlengine/test_consistency.py index a93bbee1ae..dedbe01fdf 100644 --- a/tests/integration/cqlengine/test_consistency.py +++ b/tests/integration/cqlengine/test_consistency.py @@ -53,11 +53,11 @@ def 
test_create_uses_consistency(self): qs.create(text="i am not fault tolerant this way") args = m.call_args - self.assertEqual(CL.ALL, args[0][0].consistency_level) + assert CL.ALL == args[0][0].consistency_level def test_queryset_is_returned_on_create(self): qs = TestConsistencyModel.consistency(CL.ALL) - self.assertTrue(isinstance(qs, TestConsistencyModel.__queryset__), type(qs)) + assert isinstance(qs, TestConsistencyModel.__queryset__), type(qs) def test_update_uses_consistency(self): t = TestConsistencyModel.create(text="bacon and eggs") @@ -67,7 +67,7 @@ def test_update_uses_consistency(self): t.consistency(CL.ALL).save() args = m.call_args - self.assertEqual(CL.ALL, args[0][0].consistency_level) + assert CL.ALL == args[0][0].consistency_level def test_batch_consistency(self): @@ -77,14 +77,14 @@ def test_batch_consistency(self): args = m.call_args - self.assertEqual(CL.ALL, args[0][0].consistency_level) + assert CL.ALL == args[0][0].consistency_level with mock.patch.object(self.session, 'execute') as m: with BatchQuery() as b: TestConsistencyModel.batch(b).create(text="monkey") args = m.call_args - self.assertNotEqual(CL.ALL, args[0][0].consistency_level) + assert CL.ALL != args[0][0].consistency_level def test_blind_update(self): t = TestConsistencyModel.create(text="bacon and eggs") @@ -95,7 +95,7 @@ def test_blind_update(self): TestConsistencyModel.objects(id=uid).consistency(CL.ALL).update(text="grilled cheese") args = m.call_args - self.assertEqual(CL.ALL, args[0][0].consistency_level) + assert CL.ALL == args[0][0].consistency_level def test_delete(self): # ensures we always carry consistency through on delete statements @@ -110,13 +110,13 @@ def test_delete(self): TestConsistencyModel.objects(id=uid).consistency(CL.ALL).delete() args = m.call_args - self.assertEqual(CL.ALL, args[0][0].consistency_level) + assert CL.ALL == args[0][0].consistency_level def test_default_consistency(self): # verify global assumed default - 
self.assertEqual(Session._default_consistency_level, ConsistencyLevel.LOCAL_ONE) + assert Session._default_consistency_level == ConsistencyLevel.LOCAL_ONE # verify that this session default is set according to connection.setup # assumes tests/cqlengine/__init__ setup uses CL.ONE session = connection.get_session() - self.assertEqual(session.default_consistency_level, ConsistencyLevel.ONE) + assert session.default_consistency_level == ConsistencyLevel.ONE diff --git a/tests/integration/cqlengine/test_context_query.py b/tests/integration/cqlengine/test_context_query.py index 8ced5f0f49..a922806dcf 100644 --- a/tests/integration/cqlengine/test_context_query.py +++ b/tests/integration/cqlengine/test_context_query.py @@ -17,6 +17,7 @@ from cassandra.cqlengine.models import Model from cassandra.cqlengine.query import ContextQuery from tests.integration.cqlengine.base import BaseCassEngTestCase +import pytest class TestModel(Model): @@ -68,9 +69,9 @@ def test_context_manager(self): # model keyspace write/read for ks in self.KEYSPACES: with ContextQuery(TestModel, keyspace=ks) as tm: - self.assertEqual(tm.__keyspace__, ks) + assert tm.__keyspace__ == ks - self.assertEqual(TestModel._get_keyspace(), 'ks1') + assert TestModel._get_keyspace() == 'ks1' def test_default_keyspace(self): """ @@ -87,14 +88,14 @@ def test_default_keyspace(self): TestModel.objects.create(partition=i, cluster=i) with ContextQuery(TestModel) as tm: - self.assertEqual(5, len(tm.objects.all())) + assert 5 == len(tm.objects.all()) with ContextQuery(TestModel, keyspace='ks1') as tm: - self.assertEqual(5, len(tm.objects.all())) + assert 5 == len(tm.objects.all()) for ks in self.KEYSPACES[1:]: with ContextQuery(TestModel, keyspace=ks) as tm: - self.assertEqual(0, len(tm.objects.all())) + assert 0 == len(tm.objects.all()) def test_context_keyspace(self): """ @@ -111,20 +112,20 @@ def test_context_keyspace(self): tm.objects.create(partition=i, cluster=i) with ContextQuery(TestModel, keyspace='ks4') as tm: - 
self.assertEqual(5, len(tm.objects.all())) + assert 5 == len(tm.objects.all()) - self.assertEqual(0, len(TestModel.objects.all())) + assert 0 == len(TestModel.objects.all()) for ks in self.KEYSPACES[:2]: with ContextQuery(TestModel, keyspace=ks) as tm: - self.assertEqual(0, len(tm.objects.all())) + assert 0 == len(tm.objects.all()) # simple data update with ContextQuery(TestModel, keyspace='ks4') as tm: obj = tm.objects.get(partition=1) obj.update(count=42) - self.assertEqual(42, tm.objects.get(partition=1).count) + assert 42 == tm.objects.get(partition=1).count def test_context_multiple_models(self): """ @@ -139,9 +140,9 @@ def test_context_multiple_models(self): with ContextQuery(TestModel, TestModel, keyspace='ks4') as (tm1, tm2): - self.assertNotEqual(tm1, tm2) - self.assertEqual(tm1.__keyspace__, 'ks4') - self.assertEqual(tm2.__keyspace__, 'ks4') + assert tm1 != tm2 + assert tm1.__keyspace__ == 'ks4' + assert tm2.__keyspace__ == 'ks4' def test_context_invalid_parameters(self): """ @@ -154,22 +155,22 @@ def test_context_invalid_parameters(self): @test_category query """ - with self.assertRaises(ValueError): + with pytest.raises(ValueError): with ContextQuery(keyspace='ks2'): pass - with self.assertRaises(ValueError): + with pytest.raises(ValueError): with ContextQuery(42) as tm: pass - with self.assertRaises(ValueError): + with pytest.raises(ValueError): with ContextQuery(TestModel, 42): pass - with self.assertRaises(ValueError): + with pytest.raises(ValueError): with ContextQuery(TestModel, unknown_param=42): pass - with self.assertRaises(ValueError): + with pytest.raises(ValueError): with ContextQuery(TestModel, keyspace='ks2', unknown_param=42): pass \ No newline at end of file diff --git a/tests/integration/cqlengine/test_ifexists.py b/tests/integration/cqlengine/test_ifexists.py index 9e8e5d5424..6c2ff437ab 100644 --- a/tests/integration/cqlengine/test_ifexists.py +++ b/tests/integration/cqlengine/test_ifexists.py @@ -22,6 +22,7 @@ from 
tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration import PROTOCOL_VERSION +import pytest class TestIfExistsModel(Model): @@ -95,25 +96,25 @@ def test_update_if_exists(self): m.text = 'changed' m.if_exists().update() m = TestIfExistsModel.get(id=id) - self.assertEqual(m.text, 'changed') + assert m.text == 'changed' # save() m.text = 'changed_again' m.if_exists().save() m = TestIfExistsModel.get(id=id) - self.assertEqual(m.text, 'changed_again') + assert m.text == 'changed_again' m = TestIfExistsModel(id=uuid4(), count=44) # do not exists - with self.assertRaises(LWTException) as assertion: + with pytest.raises(LWTException) as assertion: m.if_exists().update() - self.assertEqual(assertion.exception.existing.get('[applied]'), False) + assert assertion.value.existing.get('[applied]') == False # queryset update - with self.assertRaises(LWTException) as assertion: + with pytest.raises(LWTException) as assertion: TestIfExistsModel.objects(id=uuid4()).if_exists().update(count=8) - self.assertEqual(assertion.exception.existing.get('[applied]'), False) + assert assertion.value.existing.get('[applied]') == False @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_batch_update_if_exists_success(self): @@ -135,19 +136,19 @@ def test_batch_update_if_exists_success(self): m.text = '111111111' m.batch(b).if_exists().update() - with self.assertRaises(LWTException) as assertion: + with pytest.raises(LWTException) as assertion: with BatchQuery() as b: m = TestIfExistsModel(id=uuid4(), count=42) # Doesn't exist m.batch(b).if_exists().update() - self.assertEqual(assertion.exception.existing.get('[applied]'), False) + assert assertion.value.existing.get('[applied]') == False q = TestIfExistsModel.objects(id=id) - self.assertEqual(len(q), 1) + assert len(q) == 1 tm = q.first() - self.assertEqual(tm.count, 8) - self.assertEqual(tm.text, '111111111') + assert tm.count == 8 + assert tm.text == '111111111' 
@unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_batch_mixed_update_if_exists_success(self): @@ -162,14 +163,14 @@ def test_batch_mixed_update_if_exists_success(self): """ m = TestIfExistsModel2.create(id=1, count=8, text='123456789') - with self.assertRaises(LWTException) as assertion: + with pytest.raises(LWTException) as assertion: with BatchQuery() as b: m.text = '111111112' m.batch(b).if_exists().update() # Does exist n = TestIfExistsModel2(id=1, count=10, text="Failure") # Doesn't exist n.batch(b).if_exists().update() - self.assertEqual(assertion.exception.existing.get('[applied]'), False) + assert assertion.value.existing.get('[applied]') == False @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_delete_if_exists(self): @@ -188,19 +189,19 @@ def test_delete_if_exists(self): m = TestIfExistsModel.create(id=id, count=8, text='123456789') m.if_exists().delete() q = TestIfExistsModel.objects(id=id) - self.assertEqual(len(q), 0) + assert len(q) == 0 m = TestIfExistsModel(id=uuid4(), count=44) # do not exists - with self.assertRaises(LWTException) as assertion: + with pytest.raises(LWTException) as assertion: m.if_exists().delete() - self.assertEqual(assertion.exception.existing.get('[applied]'), False) + assert assertion.value.existing.get('[applied]') == False # queryset delete - with self.assertRaises(LWTException) as assertion: + with pytest.raises(LWTException) as assertion: TestIfExistsModel.objects(id=uuid4()).if_exists().delete() - self.assertEqual(assertion.exception.existing.get('[applied]'), False) + assert assertion.value.existing.get('[applied]') == False @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_batch_delete_if_exists_success(self): @@ -222,14 +223,14 @@ def test_batch_delete_if_exists_success(self): m.batch(b).if_exists().delete() q = TestIfExistsModel.objects(id=id) - self.assertEqual(len(q), 0) + 
assert len(q) == 0 - with self.assertRaises(LWTException) as assertion: + with pytest.raises(LWTException) as assertion: with BatchQuery() as b: m = TestIfExistsModel(id=uuid4(), count=42) # Doesn't exist m.batch(b).if_exists().delete() - self.assertEqual(assertion.exception.existing.get('[applied]'), False) + assert assertion.value.existing.get('[applied]') == False @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_batch_delete_mixed(self): @@ -245,15 +246,15 @@ def test_batch_delete_mixed(self): m = TestIfExistsModel2.create(id=3, count=8, text='123456789') - with self.assertRaises(LWTException) as assertion: + with pytest.raises(LWTException) as assertion: with BatchQuery() as b: m.batch(b).if_exists().delete() # Does exist n = TestIfExistsModel2(id=3, count=42, text='1111111') # Doesn't exist n.batch(b).if_exists().delete() - self.assertEqual(assertion.exception.existing.get('[applied]'), False) + assert assertion.value.existing.get('[applied]') == False q = TestIfExistsModel2.objects(id=3, count=8) - self.assertEqual(len(q), 1) + assert len(q) == 1 class IfExistsQueryTest(BaseIfExistsTest): @@ -264,7 +265,7 @@ def test_if_exists_included_on_queryset_update(self): TestIfExistsModel.objects(id=uuid4()).if_exists().update(count=42) query = m.call_args[0][0].query_string - self.assertIn("IF EXISTS", query) + assert "IF EXISTS" in query def test_if_exists_included_on_update(self): """ tests that if_exists on models update works as expected """ @@ -273,7 +274,7 @@ def test_if_exists_included_on_update(self): TestIfExistsModel(id=uuid4()).if_exists().update(count=8) query = m.call_args[0][0].query_string - self.assertIn("IF EXISTS", query) + assert "IF EXISTS" in query def test_if_exists_included_on_delete(self): """ tests that if_exists on models delete works as expected """ @@ -282,7 +283,7 @@ def test_if_exists_included_on_delete(self): TestIfExistsModel(id=uuid4()).if_exists().delete() query = 
m.call_args[0][0].query_string - self.assertIn("IF EXISTS", query) + assert "IF EXISTS" in query class IfExistWithCounterTest(BaseIfExistsWithCounterTest): @@ -298,6 +299,5 @@ def test_instance_raise_exception(self): @test_category object_mapper """ id = uuid4() - with self.assertRaises(IfExistsWithCounterColumn): + with pytest.raises(IfExistsWithCounterColumn): TestIfExistsWithCounterModel.if_exists() - diff --git a/tests/integration/cqlengine/test_ifnotexists.py b/tests/integration/cqlengine/test_ifnotexists.py index 013d4e245e..6a1dd9d4bc 100644 --- a/tests/integration/cqlengine/test_ifnotexists.py +++ b/tests/integration/cqlengine/test_ifnotexists.py @@ -22,6 +22,7 @@ from tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration import PROTOCOL_VERSION +import pytest class TestIfNotExistsModel(Model): __test__ = False @@ -80,25 +81,25 @@ def test_insert_if_not_exists(self): TestIfNotExistsModel.create(id=id, count=8, text='123456789') - with self.assertRaises(LWTException) as assertion: + with pytest.raises(LWTException): TestIfNotExistsModel.if_not_exists().create(id=id, count=9, text='111111111111') - with self.assertRaises(LWTException) as assertion: + with pytest.raises(LWTException) as assertion: TestIfNotExistsModel.objects(count=9, text='111111111111').if_not_exists().create(id=id) - self.assertEqual(assertion.exception.existing, { + assert assertion.value.existing == { 'count': 8, 'id': id, 'text': '123456789', '[applied]': False, - }) + } q = TestIfNotExistsModel.objects(id=id) - self.assertEqual(len(q), 1) + assert len(q) == 1 tm = q.first() - self.assertEqual(tm.count, 8) - self.assertEqual(tm.text, '123456789') + assert tm.count == 8 + assert tm.text == '123456789' @unittest.skipUnless(PROTOCOL_VERSION >= 2, "only runs against the cql3 protocol v2.0") def test_batch_insert_if_not_exists(self): @@ -111,22 +112,22 @@ def test_batch_insert_if_not_exists(self): b = BatchQuery() 
TestIfNotExistsModel.batch(b).if_not_exists().create(id=id, count=9, text='111111111111') - with self.assertRaises(LWTException) as assertion: + with pytest.raises(LWTException) as assertion: b.execute() - self.assertEqual(assertion.exception.existing, { + assert assertion.value.existing == { 'count': 8, 'id': id, 'text': '123456789', '[applied]': False, - }) + } q = TestIfNotExistsModel.objects(id=id) - self.assertEqual(len(q), 1) + assert len(q) == 1 tm = q.first() - self.assertEqual(tm.count, 8) - self.assertEqual(tm.text, '123456789') + assert tm.count == 8 + assert tm.text == '123456789' class IfNotExistsModelTest(BaseIfNotExistsTest): @@ -138,7 +139,7 @@ def test_if_not_exists_included_on_create(self): TestIfNotExistsModel.if_not_exists().create(count=8) query = m.call_args[0][0].query_string - self.assertIn("IF NOT EXISTS", query) + assert "IF NOT EXISTS" in query def test_if_not_exists_included_on_save(self): """ tests if we correctly put 'IF NOT EXISTS' for insert statement """ @@ -148,12 +149,12 @@ def test_if_not_exists_included_on_save(self): tm.if_not_exists().save() query = m.call_args[0][0].query_string - self.assertIn("IF NOT EXISTS", query) + assert "IF NOT EXISTS" in query def test_queryset_is_returned_on_class(self): """ ensure we get a queryset description back """ qs = TestIfNotExistsModel.if_not_exists() - self.assertTrue(isinstance(qs, TestIfNotExistsModel.__queryset__), type(qs)) + assert isinstance(qs, TestIfNotExistsModel.__queryset__), type(qs) def test_batch_if_not_exists(self): """ ensure 'IF NOT EXISTS' exists in statement when in batch """ @@ -161,7 +162,7 @@ def test_batch_if_not_exists(self): with BatchQuery() as b: TestIfNotExistsModel.batch(b).if_not_exists().create(count=8) - self.assertIn("IF NOT EXISTS", m.call_args[0][0].query_string) + assert "IF NOT EXISTS" in m.call_args[0][0].query_string class IfNotExistsInstanceTest(BaseIfNotExistsTest): @@ -174,7 +175,7 @@ def test_instance_is_returned(self): o = 
TestIfNotExistsModel.create(text="whatever") o.text = "new stuff" o = o.if_not_exists() - self.assertEqual(True, o._if_not_exists) + assert True == o._if_not_exists def test_if_not_exists_is_not_include_with_query_on_update(self): """ @@ -188,7 +189,7 @@ def test_if_not_exists_is_not_include_with_query_on_update(self): o.save() query = m.call_args[0][0].query_string - self.assertNotIn("IF NOT EXIST", query) + assert "IF NOT EXIST" not in query class IfNotExistWithCounterTest(BaseIfNotExistsWithCounterTest): @@ -198,6 +199,5 @@ def test_instance_raise_exception(self): if_not_exists on table with counter column """ id = uuid4() - with self.assertRaises(IfNotExistsWithCounterColumn): + with pytest.raises(IfNotExistsWithCounterColumn): TestIfNotExistsWithCounterModel.if_not_exists() - diff --git a/tests/integration/cqlengine/test_lwt_conditional.py b/tests/integration/cqlengine/test_lwt_conditional.py index f5ec89dd2e..f8d9d01035 100644 --- a/tests/integration/cqlengine/test_lwt_conditional.py +++ b/tests/integration/cqlengine/test_lwt_conditional.py @@ -23,6 +23,7 @@ from tests.integration.cqlengine.base import BaseCassEngTestCase from tests.integration import greaterthancass20 +import pytest class TestConditionalModel(Model): @@ -62,7 +63,7 @@ def test_update_using_conditional(self): t.iff(text='blah blah').save() args = m.call_args - self.assertIn('IF "text" = %(0)s', args[0][0].query_string) + assert 'IF "text" = %(0)s' in args[0][0].query_string def test_update_conditional_success(self): t = TestConditionalModel.if_not_exists().create(text='blah blah', count=5) @@ -71,21 +72,21 @@ def test_update_conditional_success(self): t.iff(text='blah blah').save() updated = TestConditionalModel.objects(id=id).first() - self.assertEqual(updated.count, 5) - self.assertEqual(updated.text, 'new blah') + assert updated.count == 5 + assert updated.text == 'new blah' def test_update_failure(self): t = TestConditionalModel.if_not_exists().create(text='blah blah') t.text = 'new blah' 
t = t.iff(text='something wrong') - with self.assertRaises(LWTException) as assertion: + with pytest.raises(LWTException) as assertion: t.save() - self.assertEqual(assertion.exception.existing, { + assert assertion.value.existing == { 'text': 'blah blah', '[applied]': False, - }) + } def test_blind_update(self): t = TestConditionalModel.if_not_exists().create(text='blah blah') @@ -96,27 +97,27 @@ def test_blind_update(self): TestConditionalModel.objects(id=uid).iff(text='blah blah').update(text='oh hey der') args = m.call_args - self.assertIn('IF "text" = %(1)s', args[0][0].query_string) + assert 'IF "text" = %(1)s' in args[0][0].query_string def test_blind_update_fail(self): t = TestConditionalModel.if_not_exists().create(text='blah blah') t.text = 'something else' uid = t.id qs = TestConditionalModel.objects(id=uid).iff(text='Not dis!') - with self.assertRaises(LWTException) as assertion: + with pytest.raises(LWTException) as assertion: qs.update(text='this will never work') - self.assertEqual(assertion.exception.existing, { + assert assertion.value.existing == { 'text': 'blah blah', '[applied]': False, - }) + } def test_conditional_clause(self): tc = ConditionalClause('some_value', 23) tc.set_context_id(3) - self.assertEqual('"some_value" = %(3)s', str(tc)) - self.assertEqual('"some_value" = %(3)s', str(tc)) + assert '"some_value" = %(3)s' == str(tc) + assert '"some_value" = %(3)s' == str(tc) def test_batch_update_conditional(self): t = TestConditionalModel.if_not_exists().create(text='something', count=5) @@ -125,21 +126,21 @@ def test_batch_update_conditional(self): t.batch(b).iff(count=5).update(text='something else') updated = TestConditionalModel.objects(id=id).first() - self.assertEqual(updated.text, 'something else') + assert updated.text == 'something else' b = BatchQuery() updated.batch(b).iff(count=6).update(text='and another thing') - with self.assertRaises(LWTException) as assertion: + with pytest.raises(LWTException) as assertion: b.execute() - 
self.assertEqual(assertion.exception.existing, { + assert assertion.value.existing == { 'id': id, 'count': 5, '[applied]': False, - }) + } updated = TestConditionalModel.objects(id=id).first() - self.assertEqual(updated.text, 'something else') + assert updated.text == 'something else' @unittest.skip("Skipping until PYTHON-943 is resolved") def test_batch_update_conditional_several_rows(self): @@ -155,7 +156,7 @@ def test_batch_update_conditional_several_rows(self): TestUpdateModel.batch(b).if_not_exists().create(partition=1, cluster=3, value=5, text='something else') # The response will be more than two rows because two of the inserts will fail - with self.assertRaises(LWTException): + with pytest.raises(LWTException): b.execute() first_row.delete() @@ -166,21 +167,21 @@ def test_batch_update_conditional_several_rows(self): def test_delete_conditional(self): # DML path t = TestConditionalModel.if_not_exists().create(text='something', count=5) - self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) - with self.assertRaises(LWTException): + assert TestConditionalModel.objects(id=t.id).count() == 1 + with pytest.raises(LWTException): t.iff(count=9999).delete() - self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) + assert TestConditionalModel.objects(id=t.id).count() == 1 t.iff(count=5).delete() - self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0) + assert TestConditionalModel.objects(id=t.id).count() == 0 # QuerySet path t = TestConditionalModel.if_not_exists().create(text='something', count=5) - self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) - with self.assertRaises(LWTException): + assert TestConditionalModel.objects(id=t.id).count() == 1 + with pytest.raises(LWTException): TestConditionalModel.objects(id=t.id).iff(count=9999).delete() - self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) + assert TestConditionalModel.objects(id=t.id).count() == 1 
TestConditionalModel.objects(id=t.id).iff(count=5).delete() - self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0) + assert TestConditionalModel.objects(id=t.id).count() == 0 def test_delete_lwt_ne(self): """ @@ -195,19 +196,19 @@ def test_delete_lwt_ne(self): # DML path t = TestConditionalModel.if_not_exists().create(text='something', count=5) - self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) - with self.assertRaises(LWTException): + assert TestConditionalModel.objects(id=t.id).count() == 1 + with pytest.raises(LWTException): t.iff(count__ne=5).delete() t.iff(count__ne=2).delete() - self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0) + assert TestConditionalModel.objects(id=t.id).count() == 0 # QuerySet path t = TestConditionalModel.if_not_exists().create(text='something', count=5) - self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) - with self.assertRaises(LWTException): + assert TestConditionalModel.objects(id=t.id).count() == 1 + with pytest.raises(LWTException): TestConditionalModel.objects(id=t.id).iff(count__ne=5).delete() TestConditionalModel.objects(id=t.id).iff(count__ne=2).delete() - self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 0) + assert TestConditionalModel.objects(id=t.id).count() == 0 def test_update_lwt_ne(self): """ @@ -222,20 +223,20 @@ def test_update_lwt_ne(self): # DML path t = TestConditionalModel.if_not_exists().create(text='something', count=5) - self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) - with self.assertRaises(LWTException): + assert TestConditionalModel.objects(id=t.id).count() == 1 + with pytest.raises(LWTException): t.iff(count__ne=5).update(text='nothing') t.iff(count__ne=2).update(text='nothing') - self.assertEqual(TestConditionalModel.objects(id=t.id).first().text, 'nothing') + assert TestConditionalModel.objects(id=t.id).first().text == 'nothing' t.delete() # QuerySet path t = 
TestConditionalModel.if_not_exists().create(text='something', count=5) - self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) - with self.assertRaises(LWTException): + assert TestConditionalModel.objects(id=t.id).count() == 1 + with pytest.raises(LWTException): TestConditionalModel.objects(id=t.id).iff(count__ne=5).update(text='nothing') TestConditionalModel.objects(id=t.id).iff(count__ne=2).update(text='nothing') - self.assertEqual(TestConditionalModel.objects(id=t.id).first().text, 'nothing') + assert TestConditionalModel.objects(id=t.id).first().text == 'nothing' t.delete() def test_update_to_none(self): @@ -245,36 +246,36 @@ def test_update_to_none(self): # DML path t = TestConditionalModel.if_not_exists().create(text='something', count=5) - self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) - with self.assertRaises(LWTException): + assert TestConditionalModel.objects(id=t.id).count() == 1 + with pytest.raises(LWTException): t.iff(count=9999).update(text=None) - self.assertIsNotNone(TestConditionalModel.objects(id=t.id).first().text) + assert TestConditionalModel.objects(id=t.id).first().text is not None t.iff(count=5).update(text=None) - self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text) + assert TestConditionalModel.objects(id=t.id).first().text is None # QuerySet path t = TestConditionalModel.if_not_exists().create(text='something', count=5) - self.assertEqual(TestConditionalModel.objects(id=t.id).count(), 1) - with self.assertRaises(LWTException): + assert TestConditionalModel.objects(id=t.id).count() == 1 + with pytest.raises(LWTException): TestConditionalModel.objects(id=t.id).iff(count=9999).update(text=None) - self.assertIsNotNone(TestConditionalModel.objects(id=t.id).first().text) + assert TestConditionalModel.objects(id=t.id).first().text is not None TestConditionalModel.objects(id=t.id).iff(count=5).update(text=None) - self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text) + assert 
TestConditionalModel.objects(id=t.id).first().text is None def test_column_delete_after_update(self): # DML path t = TestConditionalModel.if_not_exists().create(text='something', count=5) t.iff(count=5).update(text=None, count=6) - self.assertIsNone(t.text) - self.assertEqual(t.count, 6) + assert t.text is None + assert t.count == 6 # QuerySet path t = TestConditionalModel.if_not_exists().create(text='something', count=5) TestConditionalModel.objects(id=t.id).iff(count=5).update(text=None, count=6) - self.assertIsNone(TestConditionalModel.objects(id=t.id).first().text) - self.assertEqual(TestConditionalModel.objects(id=t.id).first().count, 6) + assert TestConditionalModel.objects(id=t.id).first().text is None + assert TestConditionalModel.objects(id=t.id).first().count == 6 def test_conditional_without_instance(self): """ @@ -294,5 +295,5 @@ def test_conditional_without_instance(self): TestConditionalModel.iff(count=5).filter(id=uuid).update(text=None, count=6) t = TestConditionalModel.filter(id=uuid).first() - self.assertIsNone(t.text) - self.assertEqual(t.count, 6) + assert t.text is None + assert t.count == 6 diff --git a/tests/integration/cqlengine/test_timestamp.py b/tests/integration/cqlengine/test_timestamp.py index 3c57e20b7d..1b0e3d6b9d 100644 --- a/tests/integration/cqlengine/test_timestamp.py +++ b/tests/integration/cqlengine/test_timestamp.py @@ -22,6 +22,7 @@ from cassandra.cqlengine.models import Model from cassandra.cqlengine.query import BatchQuery from tests.integration.cqlengine.base import BaseCassEngTestCase +import pytest class TestTimestampModel(Model): @@ -135,12 +136,12 @@ def test_non_batch(self): tmp.timestamp(timedelta(seconds=5)).delete() - with self.assertRaises(TestTimestampModel.DoesNotExist): + with pytest.raises(TestTimestampModel.DoesNotExist): TestTimestampModel.get(id=uid) tmp = TestTimestampModel.create(id=uid, count=1) - with self.assertRaises(TestTimestampModel.DoesNotExist): + with 
pytest.raises(TestTimestampModel.DoesNotExist): TestTimestampModel.get(id=uid) # calling .timestamp sets the TS on the model @@ -166,12 +167,12 @@ def test_blind_delete(self): TestTimestampModel.objects(id=uid).timestamp(timedelta(seconds=5)).delete() - with self.assertRaises(TestTimestampModel.DoesNotExist): + with pytest.raises(TestTimestampModel.DoesNotExist): TestTimestampModel.get(id=uid) tmp = TestTimestampModel.create(id=uid, count=1) - with self.assertRaises(TestTimestampModel.DoesNotExist): + with pytest.raises(TestTimestampModel.DoesNotExist): TestTimestampModel.get(id=uid) def test_blind_delete_with_datetime(self): @@ -187,12 +188,12 @@ def test_blind_delete_with_datetime(self): TestTimestampModel.objects(id=uid).timestamp(plus_five_seconds).delete() - with self.assertRaises(TestTimestampModel.DoesNotExist): + with pytest.raises(TestTimestampModel.DoesNotExist): TestTimestampModel.get(id=uid) tmp = TestTimestampModel.create(id=uid, count=1) - with self.assertRaises(TestTimestampModel.DoesNotExist): + with pytest.raises(TestTimestampModel.DoesNotExist): TestTimestampModel.get(id=uid) def test_delete_in_the_past(self): diff --git a/tests/integration/cqlengine/test_ttl.py b/tests/integration/cqlengine/test_ttl.py index 4507e91ae7..df1afb6bf0 100644 --- a/tests/integration/cqlengine/test_ttl.py +++ b/tests/integration/cqlengine/test_ttl.py @@ -94,14 +94,14 @@ def test_ttl_included_on_create(self): TestTTLModel.ttl(60).create(text="hello blake") query = m.call_args[0][0].query_string - self.assertIn("USING TTL", query) + assert "USING TTL" in query def test_queryset_is_returned_on_class(self): """ ensures we get a queryset descriptor back """ qs = TestTTLModel.ttl(60) - self.assertTrue(isinstance(qs, TestTTLModel.__queryset__), type(qs)) + assert isinstance(qs, TestTTLModel.__queryset__), type(qs) class TTLInstanceUpdateTest(BaseTTLTest): @@ -113,7 +113,7 @@ def test_update_includes_ttl(self): model.ttl(60).update(text="goodbye forever") query = 
m.call_args[0][0].query_string - self.assertIn("USING TTL", query) + assert "USING TTL" in query def test_update_syntax_valid(self): # sanity test that ensures the TTL syntax is accepted by cassandra @@ -130,7 +130,7 @@ def test_instance_is_returned(self): o = TestTTLModel.create(text="whatever") o.text = "new stuff" o = o.ttl(60) - self.assertEqual(60, o._ttl) + assert 60 == o._ttl def test_ttl_is_include_with_query_on_update(self): session = get_session() @@ -143,7 +143,7 @@ def test_ttl_is_include_with_query_on_update(self): o.save() query = m.call_args[0][0].query_string - self.assertIn("USING TTL", query) + assert "USING TTL" in query class TTLBlindUpdateTest(BaseTTLTest): @@ -157,7 +157,7 @@ def test_ttl_included_with_blind_update(self): TestTTLModel.objects(id=tid).ttl(60).update(text="bacon") query = m.call_args[0][0].query_string - self.assertIn("USING TTL", query) + assert "USING TTL" in query class TTLDefaultTest(BaseDefaultTTLTest): @@ -177,16 +177,16 @@ def test_default_ttl_not_set(self): o = TestTTLModel.create(text="some text") tid = o.id - self.assertIsNone(o._ttl) + assert o._ttl is None default_ttl = self.get_default_ttl('test_ttlmodel') - self.assertEqual(default_ttl, 0) + assert default_ttl == 0 with mock.patch.object(session, 'execute') as m: TestTTLModel.objects(id=tid).update(text="aligators") query = m.call_args[0][0].query_string - self.assertNotIn("USING TTL", query) + assert "USING TTL" not in query def test_default_ttl_set(self): session = get_session() @@ -195,29 +195,29 @@ def test_default_ttl_set(self): tid = o.id # Should not be set, it's handled by Cassandra - self.assertIsNone(o._ttl) + assert o._ttl is None default_ttl = self.get_default_ttl('test_default_ttlmodel') - self.assertEqual(default_ttl, 20) + assert default_ttl == 20 with mock.patch.object(session, 'execute') as m: TestTTLModel.objects(id=tid).update(text="aligators expired") # Should not be set either query = m.call_args[0][0].query_string - self.assertNotIn("USING 
TTL", query) + assert "USING TTL" not in query def test_default_ttl_modify(self): session = get_session() default_ttl = self.get_default_ttl('test_default_ttlmodel') - self.assertEqual(default_ttl, 20) + assert default_ttl == 20 TestDefaultTTLModel.__options__ = {'default_time_to_live': 10} sync_table(TestDefaultTTLModel) default_ttl = self.get_default_ttl('test_default_ttlmodel') - self.assertEqual(default_ttl, 10) + assert default_ttl == 10 # Restore default TTL TestDefaultTTLModel.__options__ = {'default_time_to_live': 20} @@ -229,10 +229,10 @@ def test_override_default_ttl(self): tid = o.id o.ttl(3600) - self.assertEqual(o._ttl, 3600) + assert o._ttl == 3600 with mock.patch.object(session, 'execute') as m: TestDefaultTTLModel.objects(id=tid).ttl(None).update(text="aligators expired") query = m.call_args[0][0].query_string - self.assertNotIn("USING TTL", query) + assert "USING TTL" not in query diff --git a/tests/integration/long/test_consistency.py b/tests/integration/long/test_consistency.py index 8f5fcb6313..48d0ca4ae2 100644 --- a/tests/integration/long/test_consistency.py +++ b/tests/integration/long/test_consistency.py @@ -17,6 +17,7 @@ import sys import time import traceback +import pytest from cassandra import ConsistencyLevel, OperationTimedOut, ReadTimeout, WriteTimeout, Unavailable from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT @@ -51,12 +52,12 @@ def setUp(self): self.coordinator_stats = CoordinatorStats() def _cl_failure(self, consistency_level, e): - self.fail('Instead of success, saw %s for CL.%s:\n\n%s' % ( + pytest.fail('Instead of success, saw %s for CL.%s:\n\n%s' % ( e, ConsistencyLevel.value_to_name[consistency_level], traceback.format_exc())) def _cl_expected_failure(self, cl): - self.fail('Test passed at ConsistencyLevel.%s:\n\n%s' % ( + pytest.fail('Test passed at ConsistencyLevel.%s:\n\n%s' % ( ConsistencyLevel.value_to_name[cl], traceback.format_exc())) def _insert(self, session, keyspace, count, 
consistency_level=ConsistencyLevel.ONE): @@ -101,9 +102,9 @@ def _assert_reads_succeed(self, session, keyspace, consistency_levels, expected_ self._query(session, keyspace, 1, cl) for i in range(3): if i == expected_reader: - self.coordinator_stats.assert_query_count_equals(self, i, 1) + self.coordinator_stats.assert_query_count_equals(i, 1) else: - self.coordinator_stats.assert_query_count_equals(self, i, 0) + self.coordinator_stats.assert_query_count_equals(i, 0) except Exception as e: self._cl_failure(cl, e) @@ -136,9 +137,9 @@ def _test_tokenaware_one_node_down(self, keyspace, rf, accepted): create_schema(cluster, session, keyspace, replication_factor=rf) self._insert(session, keyspace, count=1) self._query(session, keyspace, count=1) - self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 1) - self.coordinator_stats.assert_query_count_equals(self, 3, 0) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 1) + self.coordinator_stats.assert_query_count_equals(3, 0) try: force_stop(2) @@ -188,9 +189,9 @@ def test_rfthree_tokenaware_none_down(self): create_schema(cluster, session, keyspace, replication_factor=3) self._insert(session, keyspace, count=1) self._query(session, keyspace, count=1) - self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 1) - self.coordinator_stats.assert_query_count_equals(self, 3, 0) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 1) + self.coordinator_stats.assert_query_count_equals(3, 0) self.coordinator_stats.reset_counts() @@ -211,9 +212,9 @@ def _test_downgrading_cl(self, keyspace, rf, accepted): create_schema(cluster, session, keyspace, replication_factor=rf) self._insert(session, keyspace, 1) self._query(session, keyspace, 1) - 
self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 1) - self.coordinator_stats.assert_query_count_equals(self, 3, 0) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 1) + self.coordinator_stats.assert_query_count_equals(3, 0) try: force_stop(2) @@ -268,13 +269,13 @@ def rfthree_downgradingcl(self, cluster, keyspace, roundrobin): self._query(session, keyspace, count=12) if roundrobin: - self.coordinator_stats.assert_query_count_equals(self, 1, 4) - self.coordinator_stats.assert_query_count_equals(self, 2, 4) - self.coordinator_stats.assert_query_count_equals(self, 3, 4) + self.coordinator_stats.assert_query_count_equals(1, 4) + self.coordinator_stats.assert_query_count_equals(2, 4) + self.coordinator_stats.assert_query_count_equals(3, 4) else: - self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 12) - self.coordinator_stats.assert_query_count_equals(self, 3, 0) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 12) + self.coordinator_stats.assert_query_count_equals(3, 0) try: self.coordinator_stats.reset_counts() @@ -288,13 +289,13 @@ def rfthree_downgradingcl(self, cluster, keyspace, roundrobin): self.coordinator_stats.reset_counts() self._query(session, keyspace, 12, consistency_level=cl) if roundrobin: - self.coordinator_stats.assert_query_count_equals(self, 1, 6) - self.coordinator_stats.assert_query_count_equals(self, 2, 0) - self.coordinator_stats.assert_query_count_equals(self, 3, 6) + self.coordinator_stats.assert_query_count_equals(1, 6) + self.coordinator_stats.assert_query_count_equals(2, 0) + self.coordinator_stats.assert_query_count_equals(3, 6) else: - self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 0) - 
self.coordinator_stats.assert_query_count_equals(self, 3, 12) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 0) + self.coordinator_stats.assert_query_count_equals(3, 12) finally: start(2) wait_for_up(cluster, 2) @@ -360,6 +361,3 @@ def test_pool_with_host_down(self): start(node_to_stop) wait_for_up(cluster, node_to_stop) cluster.shutdown() - - - diff --git a/tests/integration/long/test_failure_types.py b/tests/integration/long/test_failure_types.py index ea8897185a..33d2c99130 100644 --- a/tests/integration/long/test_failure_types.py +++ b/tests/integration/long/test_failure_types.py @@ -32,8 +32,8 @@ get_node, start_cluster_wait_for_up, requiresmallclockgranularity, local, CASSANDRA_VERSION, TestCluster) - import unittest +import pytest log = logging.getLogger(__name__) @@ -157,13 +157,13 @@ def _perform_cql_statement(self, text, consistency_level, expected_exception, se if expected_exception is None: self.execute_helper(session, statement) else: - with self.assertRaises(expected_exception) as cm: + with pytest.raises(expected_exception) as cm: self.execute_helper(session, statement) if ProtocolVersion.uses_error_code_map(PROTOCOL_VERSION): - if isinstance(cm.exception, ReadFailure): - self.assertEqual(list(cm.exception.error_code_map.values())[0], 1) - if isinstance(cm.exception, WriteFailure): - self.assertEqual(list(cm.exception.error_code_map.values())[0], 0) + if isinstance(cm.value, ReadFailure): + assert list(cm.value.error_code_map.values())[0] == 1 + if isinstance(cm.value, WriteFailure): + assert list(cm.value.error_code_map.values())[0] == 0 def test_write_failures_from_coordinator(self): """ @@ -380,13 +380,13 @@ def test_async_timeouts(self): # Test with default timeout (should be 10) start_time = time.time() future = self.session.execute_async(ss) - with self.assertRaises(OperationTimedOut): + with pytest.raises(OperationTimedOut): future.result() end_time = time.time() total_time = 
end_time-start_time expected_time = self.cluster.profile_manager.default.request_timeout # check timeout and ensure it's within a reasonable range - self.assertAlmostEqual(expected_time, total_time, delta=.05) + assert expected_time == pytest.approx(total_time, abs=.05) # Test with user defined timeout (Should be 1) expected_time = 1 @@ -397,11 +397,11 @@ def test_async_timeouts(self): future.add_callback(mock_callback) future.add_errback(mock_errorback) - with self.assertRaises(OperationTimedOut): + with pytest.raises(OperationTimedOut): future.result() end_time = time.time() total_time = end_time-start_time # check timeout and ensure it's within a reasonable range - self.assertAlmostEqual(expected_time, total_time, delta=.05) - self.assertTrue(mock_errorback.called) - self.assertFalse(mock_callback.called) + assert expected_time == pytest.approx(total_time, abs=.05) + assert mock_errorback.called + assert not mock_callback.called diff --git a/tests/integration/long/test_ipv6.py b/tests/integration/long/test_ipv6.py index d58bad987a..1d2c7b2874 100644 --- a/tests/integration/long/test_ipv6.py +++ b/tests/integration/long/test_ipv6.py @@ -38,6 +38,7 @@ import unittest +import pytest # If more modules do IPV6 testing, this can be moved down to integration.__init__. 
@@ -83,22 +84,23 @@ def test_connect(self): session = cluster.connect() future = session.execute_async("SELECT * FROM system.local WHERE key='local'") future.result() - self.assertEqual(future._current_host.address, '::1') + assert future._current_host.address == '::1' cluster.shutdown() def test_error(self): cluster = TestCluster(connection_class=self.connection_class, contact_points=['::1'], port=9043, connect_timeout=10) - self.assertRaisesRegex(NoHostAvailable, '\(\'Unable to connect.*%s.*::1\', 9043.*Connection refused.*' - % errno.ECONNREFUSED, cluster.connect) + with pytest.raises(NoHostAvailable, match='\(\'Unable to connect.*%s.*::1\', 9043.*Connection refused.*' + % errno.ECONNREFUSED): + cluster.connect() def test_error_multiple(self): if len(socket.getaddrinfo('localhost', 9043, socket.AF_UNSPEC, socket.SOCK_STREAM)) < 2: raise unittest.SkipTest('localhost only resolves one address') cluster = TestCluster(connection_class=self.connection_class, contact_points=['localhost'], port=9043, connect_timeout=10) - self.assertRaisesRegex(NoHostAvailable, '\(\'Unable to connect.*Tried connecting to \[\(.*\(.*\].*Last error', - cluster.connect) + with pytest.raises(NoHostAvailable, match='\(\'Unable to connect.*Tried connecting to \[\(.*\(.*\].*Last error'): + cluster.connect() class LibevConnectionTests(IPV6ConnectionTest, unittest.TestCase): diff --git a/tests/integration/long/test_large_data.py b/tests/integration/long/test_large_data.py index 59873204a4..a79a6da4f5 100644 --- a/tests/integration/long/test_large_data.py +++ b/tests/integration/long/test_large_data.py @@ -28,6 +28,7 @@ from tests.integration.long.utils import create_schema import unittest +import pytest log = logging.getLogger(__name__) @@ -115,7 +116,7 @@ def test_wide_rows(self): # Verify for i, row in enumerate(results): - self.assertAlmostEqual(row['i'], i, delta=3) + assert row['i'] == pytest.approx(i, abs=3) session.cluster.shutdown() @@ -161,11 +162,11 @@ def test_wide_batch_rows(self): 
lastvalue = 0 for j, row in enumerate(results): lastValue=row['i'] - self.assertEqual(lastValue, j) + assert lastValue == j #check the last value make sure it's what we expect index_value = to_insert-1 - self.assertEqual(lastValue,index_value,"Verification failed only found {0} inserted we were expecting {1}".format(j,index_value)) + assert lastValue == index_value, "Verification failed only found {0} inserted we were expecting {1}".format(j,index_value) session.cluster.shutdown() @@ -200,9 +201,9 @@ def test_wide_byte_rows(self): # Verify bb = pack('>H', 0xCAFE) for i, row in enumerate(results): - self.assertEqual(row['v'], bb) + assert row['v'] == bb - self.assertGreaterEqual(i, expected_results, "Verification failed only found {0} inserted we were expecting {1}".format(i,expected_results)) + assert i >= expected_results, "Verification failed only found {0} inserted we were expecting {1}".format(i,expected_results) session.cluster.shutdown() @@ -235,9 +236,9 @@ def test_large_text(self): # Verify found_result = False for i, row in enumerate(result): - self.assertEqual(row['txt'], text) + assert row['txt'] == text found_result = True - self.assertTrue(found_result, "No results were found") + assert found_result, "No results were found" session.cluster.shutdown() @@ -266,6 +267,6 @@ def test_wide_table(self): # Verify for row in result: for i in range(table_width): - self.assertEqual(row[create_column_name(i)], i) + assert row[create_column_name(i)] == i session.cluster.shutdown() diff --git a/tests/integration/long/test_loadbalancingpolicies.py b/tests/integration/long/test_loadbalancingpolicies.py index a6dff4d786..fd8edde14c 100644 --- a/tests/integration/long/test_loadbalancingpolicies.py +++ b/tests/integration/long/test_loadbalancingpolicies.py @@ -16,6 +16,7 @@ import struct import sys import traceback +import pytest from cassandra import cqltypes from cassandra import ConsistencyLevel, Unavailable, OperationTimedOut, ReadTimeout, ReadFailure, \ @@ -185,9 
+186,9 @@ def test_token_aware_is_used_by_default(self): self.addCleanup(cluster.shutdown) if murmur3 is not None: - self.assertTrue(isinstance(cluster.profile_manager.default.load_balancing_policy, TokenAwarePolicy)) + assert isinstance(cluster.profile_manager.default.load_balancing_policy, TokenAwarePolicy) else: - self.assertTrue(isinstance(cluster.profile_manager.default.load_balancing_policy, DCAwareRoundRobinPolicy)) + assert isinstance(cluster.profile_manager.default.load_balancing_policy, DCAwareRoundRobinPolicy) def test_roundrobin(self): use_singledc() @@ -200,9 +201,9 @@ def test_roundrobin(self): self._insert(session, keyspace) self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 1, 4) - self.coordinator_stats.assert_query_count_equals(self, 2, 4) - self.coordinator_stats.assert_query_count_equals(self, 3, 4) + self.coordinator_stats.assert_query_count_equals(1, 4) + self.coordinator_stats.assert_query_count_equals(2, 4) + self.coordinator_stats.assert_query_count_equals(3, 4) force_stop(3) self._wait_for_nodes_down([3], cluster) @@ -210,9 +211,9 @@ def test_roundrobin(self): self.coordinator_stats.reset_counts() self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 1, 6) - self.coordinator_stats.assert_query_count_equals(self, 2, 6) - self.coordinator_stats.assert_query_count_equals(self, 3, 0) + self.coordinator_stats.assert_query_count_equals(1, 6) + self.coordinator_stats.assert_query_count_equals(2, 6) + self.coordinator_stats.assert_query_count_equals(3, 0) decommission(1) start(3) @@ -222,9 +223,9 @@ def test_roundrobin(self): self.coordinator_stats.reset_counts() self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 6) - self.coordinator_stats.assert_query_count_equals(self, 3, 6) + self.coordinator_stats.assert_query_count_equals(1, 0) + 
self.coordinator_stats.assert_query_count_equals(2, 6) + self.coordinator_stats.assert_query_count_equals(3, 6) def test_roundrobin_two_dcs(self): use_multidc([2, 2]) @@ -237,10 +238,10 @@ def test_roundrobin_two_dcs(self): self._insert(session, keyspace) self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 1, 3) - self.coordinator_stats.assert_query_count_equals(self, 2, 3) - self.coordinator_stats.assert_query_count_equals(self, 3, 3) - self.coordinator_stats.assert_query_count_equals(self, 4, 3) + self.coordinator_stats.assert_query_count_equals(1, 3) + self.coordinator_stats.assert_query_count_equals(2, 3) + self.coordinator_stats.assert_query_count_equals(3, 3) + self.coordinator_stats.assert_query_count_equals(4, 3) force_stop(1) bootstrap(5, 'dc3') @@ -253,11 +254,11 @@ def test_roundrobin_two_dcs(self): self.coordinator_stats.reset_counts() self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 3) - self.coordinator_stats.assert_query_count_equals(self, 3, 3) - self.coordinator_stats.assert_query_count_equals(self, 4, 3) - self.coordinator_stats.assert_query_count_equals(self, 5, 3) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 3) + self.coordinator_stats.assert_query_count_equals(3, 3) + self.coordinator_stats.assert_query_count_equals(4, 3) + self.coordinator_stats.assert_query_count_equals(5, 3) def test_roundrobin_two_dcs_2(self): use_multidc([2, 2]) @@ -270,10 +271,10 @@ def test_roundrobin_two_dcs_2(self): self._insert(session, keyspace) self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 1, 3) - self.coordinator_stats.assert_query_count_equals(self, 2, 3) - self.coordinator_stats.assert_query_count_equals(self, 3, 3) - self.coordinator_stats.assert_query_count_equals(self, 4, 3) + 
self.coordinator_stats.assert_query_count_equals(1, 3) + self.coordinator_stats.assert_query_count_equals(2, 3) + self.coordinator_stats.assert_query_count_equals(3, 3) + self.coordinator_stats.assert_query_count_equals(4, 3) force_stop(1) bootstrap(5, 'dc1') @@ -286,11 +287,11 @@ def test_roundrobin_two_dcs_2(self): self.coordinator_stats.reset_counts() self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 3) - self.coordinator_stats.assert_query_count_equals(self, 3, 3) - self.coordinator_stats.assert_query_count_equals(self, 4, 3) - self.coordinator_stats.assert_query_count_equals(self, 5, 3) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 3) + self.coordinator_stats.assert_query_count_equals(3, 3) + self.coordinator_stats.assert_query_count_equals(4, 3) + self.coordinator_stats.assert_query_count_equals(5, 3) def test_dc_aware_roundrobin_two_dcs(self): use_multidc([3, 2]) @@ -303,11 +304,11 @@ def test_dc_aware_roundrobin_two_dcs(self): self._insert(session, keyspace) self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 1, 4) - self.coordinator_stats.assert_query_count_equals(self, 2, 4) - self.coordinator_stats.assert_query_count_equals(self, 3, 4) - self.coordinator_stats.assert_query_count_equals(self, 4, 0) - self.coordinator_stats.assert_query_count_equals(self, 5, 0) + self.coordinator_stats.assert_query_count_equals(1, 4) + self.coordinator_stats.assert_query_count_equals(2, 4) + self.coordinator_stats.assert_query_count_equals(3, 4) + self.coordinator_stats.assert_query_count_equals(4, 0) + self.coordinator_stats.assert_query_count_equals(5, 0) def test_dc_aware_roundrobin_two_dcs_2(self): use_multidc([3, 2]) @@ -320,11 +321,11 @@ def test_dc_aware_roundrobin_two_dcs_2(self): self._insert(session, keyspace) self._query(session, keyspace) - 
self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 0) - self.coordinator_stats.assert_query_count_equals(self, 3, 0) - self.coordinator_stats.assert_query_count_equals(self, 4, 6) - self.coordinator_stats.assert_query_count_equals(self, 5, 6) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 0) + self.coordinator_stats.assert_query_count_equals(3, 0) + self.coordinator_stats.assert_query_count_equals(4, 6) + self.coordinator_stats.assert_query_count_equals(5, 6) def test_dc_aware_roundrobin_one_remote_host(self): use_multidc([2, 2]) @@ -337,10 +338,10 @@ def test_dc_aware_roundrobin_one_remote_host(self): self._insert(session, keyspace) self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 0) - self.coordinator_stats.assert_query_count_equals(self, 3, 6) - self.coordinator_stats.assert_query_count_equals(self, 4, 6) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 0) + self.coordinator_stats.assert_query_count_equals(3, 6) + self.coordinator_stats.assert_query_count_equals(4, 6) self.coordinator_stats.reset_counts() bootstrap(5, 'dc1') @@ -348,11 +349,11 @@ def test_dc_aware_roundrobin_one_remote_host(self): self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 0) - self.coordinator_stats.assert_query_count_equals(self, 3, 6) - self.coordinator_stats.assert_query_count_equals(self, 4, 6) - self.coordinator_stats.assert_query_count_equals(self, 5, 0) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 0) + self.coordinator_stats.assert_query_count_equals(3, 6) + self.coordinator_stats.assert_query_count_equals(4, 6) + 
self.coordinator_stats.assert_query_count_equals(5, 0) self.coordinator_stats.reset_counts() decommission(3) @@ -361,12 +362,12 @@ def test_dc_aware_roundrobin_one_remote_host(self): self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 3, 0) - self.coordinator_stats.assert_query_count_equals(self, 4, 0) + self.coordinator_stats.assert_query_count_equals(3, 0) + self.coordinator_stats.assert_query_count_equals(4, 0) responses = set() for node in [1, 2, 5]: responses.add(self.coordinator_stats.get_query_count(node)) - self.assertEqual(set([0, 0, 12]), responses) + assert set([0, 0, 12]) == responses self.coordinator_stats.reset_counts() decommission(5) @@ -374,13 +375,13 @@ def test_dc_aware_roundrobin_one_remote_host(self): self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 3, 0) - self.coordinator_stats.assert_query_count_equals(self, 4, 0) - self.coordinator_stats.assert_query_count_equals(self, 5, 0) + self.coordinator_stats.assert_query_count_equals(3, 0) + self.coordinator_stats.assert_query_count_equals(4, 0) + self.coordinator_stats.assert_query_count_equals(5, 0) responses = set() for node in [1, 2]: responses.add(self.coordinator_stats.get_query_count(node)) - self.assertEqual(set([0, 12]), responses) + assert set([0, 12]) == responses self.coordinator_stats.reset_counts() decommission(1) @@ -388,20 +389,17 @@ def test_dc_aware_roundrobin_one_remote_host(self): self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 12) - self.coordinator_stats.assert_query_count_equals(self, 3, 0) - self.coordinator_stats.assert_query_count_equals(self, 4, 0) - self.coordinator_stats.assert_query_count_equals(self, 5, 0) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 12) + self.coordinator_stats.assert_query_count_equals(3, 0) + 
self.coordinator_stats.assert_query_count_equals(4, 0) + self.coordinator_stats.assert_query_count_equals(5, 0) self.coordinator_stats.reset_counts() force_stop(2) - try: + with pytest.raises(NoHostAvailable): self._query(session, keyspace) - self.fail() - except NoHostAvailable: - pass def test_token_aware(self): keyspace = 'test_token_aware' @@ -421,28 +419,26 @@ def token_aware(self, keyspace, use_prepared=False): self._insert(session, keyspace) self._query(session, keyspace, use_prepared=use_prepared) - self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 12) - self.coordinator_stats.assert_query_count_equals(self, 3, 0) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 12) + self.coordinator_stats.assert_query_count_equals(3, 0) self.coordinator_stats.reset_counts() self._query(session, keyspace, use_prepared=use_prepared) - self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 12) - self.coordinator_stats.assert_query_count_equals(self, 3, 0) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 12) + self.coordinator_stats.assert_query_count_equals(3, 0) self.coordinator_stats.reset_counts() force_stop(2) self._wait_for_nodes_down([2], cluster) - try: + with pytest.raises(Unavailable) as e: self._query(session, keyspace, use_prepared=use_prepared) - self.fail() - except Unavailable as e: - self.assertEqual(e.consistency, 1) - self.assertEqual(e.required_replicas, 1) - self.assertEqual(e.alive_replicas, 0) + assert e.value.consistency == 1 + assert e.value.required_replicas == 1 + assert e.value.alive_replicas == 0 self.coordinator_stats.reset_counts() start(2) @@ -450,19 +446,16 @@ def token_aware(self, keyspace, use_prepared=False): self._query(session, keyspace, use_prepared=use_prepared) - 
self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 12) - self.coordinator_stats.assert_query_count_equals(self, 3, 0) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 12) + self.coordinator_stats.assert_query_count_equals(3, 0) self.coordinator_stats.reset_counts() stop(2) self._wait_for_nodes_down([2], cluster) - try: + with pytest.raises(Unavailable): self._query(session, keyspace, use_prepared=use_prepared) - self.fail() - except Unavailable: - pass self.coordinator_stats.reset_counts() start(2) @@ -476,8 +469,8 @@ def token_aware(self, keyspace, use_prepared=False): self.coordinator_stats.get_query_count(1), self.coordinator_stats.get_query_count(3) ]) - self.assertEqual(results, set([0, 12])) - self.coordinator_stats.assert_query_count_equals(self, 2, 0) + assert results == set([0, 12]) + self.coordinator_stats.assert_query_count_equals(2, 0) def test_token_aware_composite_key(self): use_singledc() @@ -500,18 +493,16 @@ def test_token_aware_composite_key(self): '(?, ?, ?)' % table) bound = prepared.bind((1, 2, 3)) result = session.execute(bound) - self.assertIn(result.response_future.attempted_hosts[0], - cluster.metadata.get_replicas(keyspace, bound.routing_key)) + assert result.response_future.attempted_hosts[0] in cluster.metadata.get_replicas(keyspace, bound.routing_key) # There could be race condition with querying a node # which doesn't yet have the data so we query one of # the replicas results = session.execute(SimpleStatement('SELECT * FROM %s WHERE k1 = 1 AND k2 = 2' % table, routing_key=bound.routing_key)) - self.assertIn(results.response_future.attempted_hosts[0], - cluster.metadata.get_replicas(keyspace, bound.routing_key)) + assert results.response_future.attempted_hosts[0] in cluster.metadata.get_replicas(keyspace, bound.routing_key) - self.assertTrue(results[0].i) + assert results[0].i def 
test_token_aware_with_rf_2(self, use_prepared=False): use_singledc() @@ -524,9 +515,9 @@ def test_token_aware_with_rf_2(self, use_prepared=False): self._insert(session, keyspace) self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 12) - self.coordinator_stats.assert_query_count_equals(self, 3, 0) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 12) + self.coordinator_stats.assert_query_count_equals(3, 0) self.coordinator_stats.reset_counts() stop(2) @@ -534,9 +525,9 @@ def test_token_aware_with_rf_2(self, use_prepared=False): self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 0) - self.coordinator_stats.assert_query_count_equals(self, 3, 12) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 0) + self.coordinator_stats.assert_query_count_equals(3, 12) def test_token_aware_with_local_table(self): use_singledc() @@ -547,7 +538,7 @@ def test_token_aware_with_local_table(self): p = session.prepare("SELECT * FROM system.local WHERE key=?") # this would blow up prior to 61b4fad r = session.execute(p, ('local',)) - self.assertEqual(r[0].key, 'local') + assert r[0].key == 'local' def test_token_aware_with_shuffle_rf2(self): """ @@ -572,9 +563,9 @@ def test_token_aware_with_shuffle_rf2(self): self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 0) - self.coordinator_stats.assert_query_count_equals(self, 3, 12) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 0) + self.coordinator_stats.assert_query_count_equals(3, 12) def test_token_aware_with_shuffle_rf3(self): """ @@ -599,10 +590,10 @@ def 
test_token_aware_with_shuffle_rf3(self): self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 1, 0) + self.coordinator_stats.assert_query_count_equals(1, 0) query_count_two = self.coordinator_stats.get_query_count(2) query_count_three = self.coordinator_stats.get_query_count(3) - self.assertEqual(query_count_two + query_count_three, 12) + assert query_count_two + query_count_three == 12 self.coordinator_stats.reset_counts() stop(2) @@ -610,9 +601,9 @@ def test_token_aware_with_shuffle_rf3(self): self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 0) - self.coordinator_stats.assert_query_count_equals(self, 3, 12) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 0) + self.coordinator_stats.assert_query_count_equals(3, 12) @greaterthanorequalcass40 def test_token_aware_with_transient_replication(self): @@ -644,15 +635,15 @@ def test_token_aware_with_transient_replication(self): f = session.execute_async(query, (i,), trace=True) full_dc1_replicas = [h for h in cluster.metadata.get_replicas('test_tr', cqltypes.Int32Type.serialize(i, cluster.protocol_version)) if h.datacenter == 'dc1'] - self.assertEqual(len(full_dc1_replicas), 2) + assert len(full_dc1_replicas) == 2 f.result() trace_hosts = [cluster.metadata.get_host(e.source) for e in f.get_query_trace().events] for h in f.attempted_hosts: - self.assertIn(h, full_dc1_replicas) + assert h in full_dc1_replicas for h in trace_hosts: - self.assertIn(h, full_dc1_replicas) + assert h in full_dc1_replicas def _set_up_shuffle_test(self, keyspace, replication_factor): @@ -681,7 +672,7 @@ def _check_query_order_changes(self, session, keyspace): self.coordinator_stats.get_query_count(3)) query_counts.add(loop_qcs) - self.assertEqual(sum(loop_qcs), 12) + assert sum(loop_qcs) == 12 # end the loop if we get more than one query ordering 
self.coordinator_stats.reset_counts() @@ -706,24 +697,21 @@ def test_white_list(self): self._insert(session, keyspace) self._query(session, keyspace) - self.coordinator_stats.assert_query_count_equals(self, 1, 0) - self.coordinator_stats.assert_query_count_equals(self, 2, 12) - self.coordinator_stats.assert_query_count_equals(self, 3, 0) + self.coordinator_stats.assert_query_count_equals(1, 0) + self.coordinator_stats.assert_query_count_equals(2, 12) + self.coordinator_stats.assert_query_count_equals(3, 0) # white list policy should not allow reconnecting to ignored hosts force_stop(3) self._wait_for_nodes_down([3]) - self.assertFalse(cluster.metadata.get_host(IP_FORMAT % 3).is_currently_reconnecting()) + assert not cluster.metadata.get_host(IP_FORMAT % 3).is_currently_reconnecting() self.coordinator_stats.reset_counts() force_stop(2) self._wait_for_nodes_down([2]) - try: + with pytest.raises(NoHostAvailable): self._query(session, keyspace) - self.fail() - except NoHostAvailable: - pass def test_black_list_with_host_filter_policy(self): """ @@ -752,7 +740,7 @@ def test_black_list_with_host_filter_policy(self): session = cluster.connect() self._wait_for_nodes_up([1, 2, 3]) - self.assertNotIn(ignored_address, [h.address for h in hfp.make_query_plan()]) + assert ignored_address not in [h.address for h in hfp.make_query_plan()] create_schema(cluster, session, keyspace) self._insert(session, keyspace) @@ -763,12 +751,12 @@ def test_black_list_with_host_filter_policy(self): # will be 4 and for the other 8 first_node_count = self.coordinator_stats.get_query_count(1) third_node_count = self.coordinator_stats.get_query_count(3) - self.assertEqual(first_node_count + third_node_count, 12) - self.assertTrue(first_node_count == 8 or first_node_count == 4) + assert first_node_count + third_node_count == 12 + assert first_node_count == 8 or first_node_count == 4 - self.coordinator_stats.assert_query_count_equals(self, 2, 0) + self.coordinator_stats.assert_query_count_equals(2, 0) 
# policy should not allow reconnecting to ignored host force_stop(2) self._wait_for_nodes_down([2]) - self.assertFalse(cluster.metadata.get_host(ignored_address).is_currently_reconnecting()) + assert not cluster.metadata.get_host(ignored_address).is_currently_reconnecting() diff --git a/tests/integration/long/test_policies.py b/tests/integration/long/test_policies.py index 33f35ced0d..ab8d125ab1 100644 --- a/tests/integration/long/test_policies.py +++ b/tests/integration/long/test_policies.py @@ -18,6 +18,7 @@ from cassandra.cluster import ExecutionProfile, EXEC_PROFILE_DEFAULT from tests.integration import use_cluster, get_cluster, get_node, TestCluster +import pytest def setup_module(): @@ -58,10 +59,10 @@ def test_should_rethrow_on_unvailable_with_default_policy_if_cas(self): # supported as conditional update commit consistency. ...."" # after fix: cassandra.Unavailable (expected since replicas are down) - with self.assertRaises(Unavailable) as cm: + with pytest.raises(Unavailable) as cm: session.execute("update test_retry_policy_cas.t set data = 'staging' where id = 42 if data ='testing'") - exception = cm.exception - self.assertEqual(exception.consistency, ConsistencyLevel.SERIAL) - self.assertEqual(exception.required_replicas, 2) - self.assertEqual(exception.alive_replicas, 1) + exception = cm.value + assert exception.consistency == ConsistencyLevel.SERIAL + assert exception.required_replicas == 2 + assert exception.alive_replicas == 1 diff --git a/tests/integration/long/test_schema.py b/tests/integration/long/test_schema.py index f1cc80a17a..f892acba52 100644 --- a/tests/integration/long/test_schema.py +++ b/tests/integration/long/test_schema.py @@ -156,6 +156,6 @@ def test_for_schema_disagreement_attribute(self): def check_and_wait_for_agreement(self, session, rs, exepected): # Wait for RESULT_KIND_SCHEMA_CHANGE message to arrive time.sleep(1) - self.assertEqual(rs.response_future.is_schema_agreed, exepected) + assert rs.response_future.is_schema_agreed == 
exepected if not rs.response_future.is_schema_agreed: session.cluster.control_connection.wait_for_schema_agreement(wait_time=1000) diff --git a/tests/integration/long/test_ssl.py b/tests/integration/long/test_ssl.py index 070e2fe268..56dc6a5c2d 100644 --- a/tests/integration/long/test_ssl.py +++ b/tests/integration/long/test_ssl.py @@ -25,6 +25,7 @@ from tests.integration import ( get_cluster, remove_cluster, use_single_node, start_cluster_wait_for_up, EVENT_LOOP_MANAGER, TestCluster ) +import pytest if not hasattr(ssl, 'match_hostname'): try: @@ -290,7 +291,7 @@ def test_cannot_connect_without_client_auth(self): cluster = TestCluster(ssl_options={'ca_certs': CLIENT_CA_CERTS, 'ssl_version': ssl_version}) - with self.assertRaises(NoHostAvailable) as _: + with pytest.raises(NoHostAvailable): cluster.connect() cluster.shutdown() @@ -322,7 +323,7 @@ def test_cannot_connect_with_bad_client_auth(self): 'keyfile': DRIVER_KEYFILE} ) - with self.assertRaises(NoHostAvailable) as _: + with pytest.raises(NoHostAvailable): cluster.connect() cluster.shutdown() @@ -333,7 +334,7 @@ def test_cannot_connect_with_invalid_hostname(self): 'certfile': DRIVER_CERTFILE} ssl_options.update(verify_certs) - with self.assertRaises(Exception): + with pytest.raises(Exception): validate_ssl_options(ssl_options=ssl_options, hostname='localhost') @@ -487,7 +488,7 @@ def test_cannot_connect_ssl_context_with_invalid_hostname(self): ) ssl_context.verify_mode = ssl.CERT_REQUIRED ssl_options["check_hostname"] = True - with self.assertRaises(Exception): + with pytest.raises(Exception): validate_ssl_options(ssl_context=ssl_context, ssl_options=ssl_options, hostname="localhost") @unittest.skipIf(USES_PYOPENSSL, "This test is for the built-in ssl.Context") diff --git a/tests/integration/long/test_topology_change.py b/tests/integration/long/test_topology_change.py index 5b12eef28c..80540cfb2f 100644 --- a/tests/integration/long/test_topology_change.py +++ b/tests/integration/long/test_topology_change.py @@ 
-39,10 +39,10 @@ def test_removed_node_stops_reconnecting(self): get_node(3).nodetool("disablebinary") wait_until(condition=lambda: state_listener.downed_host is not None, delay=2, max_attempts=50) - self.assertTrue(state_listener.downed_host.is_currently_reconnecting()) + assert state_listener.downed_host.is_currently_reconnecting() decommission(3) wait_until(condition=lambda: state_listener.removed_host is not None, delay=2, max_attempts=50) - self.assertIs(state_listener.downed_host, state_listener.removed_host) # Just a sanity check - self.assertFalse(state_listener.removed_host.is_currently_reconnecting()) + assert state_listener.downed_host is state_listener.removed_host # Just a sanity check + assert not state_listener.removed_host.is_currently_reconnecting() diff --git a/tests/integration/long/utils.py b/tests/integration/long/utils.py index a3ae705a34..93464df8ff 100644 --- a/tests/integration/long/utils.py +++ b/tests/integration/long/utils.py @@ -14,6 +14,7 @@ import logging import time +import pytest from collections import defaultdict from packaging.version import Version @@ -48,10 +49,10 @@ def get_query_count(self, node): ip = '127.0.0.%d' % node return self.coordinator_counts[ip] - def assert_query_count_equals(self, testcase, node, expected): + def assert_query_count_equals(self, node, expected): ip = '127.0.0.%d' % node if self.get_query_count(node) != expected: - testcase.fail('Expected %d queries to %s, but got %d. Query counts: %s' % ( + pytest.fail('Expected %d queries to %s, but got %d. 
Query counts: %s' % ( expected, ip, self.coordinator_counts[ip], dict(self.coordinator_counts))) diff --git a/tests/integration/simulacron/test_backpressure.py b/tests/integration/simulacron/test_backpressure.py index 69c38da8fe..0b84f73e29 100644 --- a/tests/integration/simulacron/test_backpressure.py +++ b/tests/integration/simulacron/test_backpressure.py @@ -19,6 +19,7 @@ from tests.integration import requiressimulacron, libevtest from tests.integration.simulacron import SimulacronBase, PROTOCOL_VERSION from tests.integration.simulacron.utils import ResumeReads, PauseReads, prime_request, start_and_prime_singledc +import pytest @requiressimulacron @@ -70,10 +71,10 @@ def test_paused_connections(self): # Make sure we actually have some stuck in-flight requests for in_flight in [pool._connection.in_flight for pool in session.get_pools()]: - self.assertGreater(in_flight, 100) + assert in_flight > 100 time.sleep(.5) for in_flight in [pool._connection.in_flight for pool in session.get_pools()]: - self.assertGreater(in_flight, 100) + assert in_flight > 100 prime_request(ResumeReads()) @@ -83,7 +84,7 @@ def test_paused_connections(self): except NoHostAvailable as e: # We shouldn't have any timeouts here, but all of the queries beyond what can fit # in the tcp buffer will have returned with a ConnectionBusy exception - self.assertIn("ConnectionBusy", str(e)) + assert "ConnectionBusy" in str(e) # Verify that we can continue sending queries without any problems for host in session.cluster.metadata.all_hosts(): @@ -121,9 +122,9 @@ def test_queued_requests_timeout(self): # Simulacron will respond to a couple queries before cutting off reads, so we'll just verify # that only "a few" successes happened here - self.assertLess(successes, 50) - self.assertLess(self.callback_successes, 50) - self.assertEqual(self.callback_errors, len(futures) - self.callback_successes) + assert successes < 50 + assert self.callback_successes < 50 + assert self.callback_errors == len(futures) - 
self.callback_successes def test_cluster_busy(self): """ Verify that once TCP buffer is full we get busy exceptions rather than timeouts """ @@ -146,9 +147,9 @@ def test_cluster_busy(self): # Now that our send buffer is completely full, verify we immediately get busy exceptions rather than timing out for i in range(1000): - with self.assertRaises(NoHostAvailable) as e: + with pytest.raises(NoHostAvailable) as e: session.execute(query, [str(i)]) - self.assertIn("ConnectionBusy", str(e.exception)) + assert "ConnectionBusy" in str(e.value) def test_node_busy(self): """ Verify that once TCP buffer is full, queries continue to get re-routed to other nodes """ @@ -176,4 +177,3 @@ def test_node_busy(self): # verify queries get re-routed to other nodes and queries complete successfully for i in range(1000): session.execute(query, [str(i)]) - diff --git a/tests/integration/simulacron/test_cluster.py b/tests/integration/simulacron/test_cluster.py index 53aa9936fc..898734c416 100644 --- a/tests/integration/simulacron/test_cluster.py +++ b/tests/integration/simulacron/test_cluster.py @@ -24,6 +24,7 @@ from cassandra import (WriteTimeout, WriteType, ConsistencyLevel, UnresolvableContactPoints) from cassandra.cluster import Cluster, ControlConnection +import pytest PROTOCOL_VERSION = min(4, PROTOCOL_VERSION) @@ -48,17 +49,17 @@ def test_writetimeout(self): } prime_query(query_to_prime_simple, then=then, rows=None, column_types=None) - with self.assertRaises(WriteTimeout) as assert_raised_context: + with pytest.raises(WriteTimeout) as assert_raised_context: self.session.execute(query_to_prime_simple) - wt = assert_raised_context.exception - self.assertEqual(wt.write_type, WriteType.name_to_value[write_type]) - self.assertEqual(wt.consistency, ConsistencyLevel.name_to_value[consistency]) - self.assertEqual(wt.received_responses, received_responses) - self.assertEqual(wt.required_responses, required_responses) - self.assertIn(write_type, str(wt)) - self.assertIn(consistency, 
str(wt)) - self.assertIn(str(received_responses), str(wt)) - self.assertIn(str(required_responses), str(wt)) + wt = assert_raised_context.value + assert wt.write_type == WriteType.name_to_value[write_type] + assert wt.consistency == ConsistencyLevel.name_to_value[consistency] + assert wt.received_responses == received_responses + assert wt.required_responses == required_responses + assert write_type in str(wt) + assert consistency in str(wt) + assert str(received_responses) in str(wt) + assert str(required_responses) in str(wt) @requiressimulacron @@ -77,7 +78,7 @@ def test_connection_with_one_unresolvable_contact_point(self): compression=False) def test_connection_with_only_unresolvable_contact_points(self): - with self.assertRaises(UnresolvableContactPoints): + with pytest.raises(UnresolvableContactPoints): self.cluster = Cluster(['dns.invalid'], protocol_version=PROTOCOL_VERSION, compression=False) @@ -102,6 +103,6 @@ def test_duplicate(self): session = cluster.connect(wait_for_all_pools=True) warnings = mock_handler.messages.get("warning") - self.assertEqual(len(warnings), 1) - self.assertTrue('multiple hosts with the same endpoint' in warnings[0]) + assert len(warnings) == 1 + assert 'multiple hosts with the same endpoint' in warnings[0] cluster.shutdown() diff --git a/tests/integration/simulacron/test_connection.py b/tests/integration/simulacron/test_connection.py index 95df69e44c..818d0b46b9 100644 --- a/tests/integration/simulacron/test_connection.py +++ b/tests/integration/simulacron/test_connection.py @@ -36,6 +36,7 @@ start_and_prime_singledc, clear_queries, RejectConnections, RejectType, AcceptConnections, PauseReads, ResumeReads) +import pytest class TrackDownListener(HostStateListener): @@ -141,7 +142,7 @@ def test_heart_beat_timeout(self): for f in futures: f._event.wait() - self.assertIsInstance(f._final_exception, OperationTimedOut) + assert isinstance(f._final_exception, OperationTimedOut) prime_request(PrimeOptions(then=NO_THEN)) @@ -150,10 
+151,10 @@ def test_heart_beat_timeout(self): time.sleep((idle_heartbeat_timeout + idle_heartbeat_interval) * 2.5) for host in cluster.metadata.all_hosts(): - self.assertIn(host, listener.hosts_marked_down) + assert host in listener.hosts_marked_down # In this case HostConnection._replace shouldn't be called - self.assertNotIn("_replace", executor.called_functions) + assert "_replace" not in executor.called_functions def test_callbacks_and_pool_when_oto(self): """ @@ -181,9 +182,10 @@ def test_callbacks_and_pool_when_oto(self): future = session.execute_async(query_to_prime, timeout=1) callback, errback = Mock(name='callback'), Mock(name='errback') future.add_callbacks(callback, errback) - self.assertRaises(OperationTimedOut, future.result) + with pytest.raises(OperationTimedOut): + future.result() - assert_quiescent_pool_state(self, cluster) + assert_quiescent_pool_state(cluster) time.sleep(server_delay + 1) # PYTHON-630 -- only the errback should be called @@ -261,7 +263,8 @@ def connection_factory(self, *args, **kwargs): prime_request(PrimeOptions(then={"result": "no_result", "delay_in_ms": never})) prime_request(RejectConnections("unbind")) - self.assertRaisesRegex(OperationTimedOut, "Connection defunct by heartbeat", future.result) + with pytest.raises(OperationTimedOut, match="Connection defunct by heartbeat"): + future.result() def test_close_when_query(self): """ @@ -289,7 +292,8 @@ def test_close_when_query(self): } prime_query(query_to_prime, rows=None, column_types=None, then=then) - self.assertRaises(NoHostAvailable, session.execute, query_to_prime) + with pytest.raises(NoHostAvailable): + session.execute(query_to_prime) def test_retry_after_defunct(self): """ @@ -345,7 +349,7 @@ def test_retry_after_defunct(self): response_future = session.execute_async(query_to_prime, timeout=4 * idle_heartbeat_interval + idle_heartbeat_timeout) response_future.result() - self.assertGreater(len(response_future.attempted_hosts), 1) + assert 
len(response_future.attempted_hosts) > 1 # No error should be raised here since the hosts have been marked # as down and there's still 1 DC available @@ -354,11 +358,11 @@ def test_retry_after_defunct(self): # Might take some time to close the previous connections and reconnect time.sleep(10) - assert_quiescent_pool_state(self, cluster) + assert_quiescent_pool_state(cluster) clear_queries() time.sleep(10) - assert_quiescent_pool_state(self, cluster) + assert_quiescent_pool_state(cluster) def test_idle_connection_is_not_closed(self): """ @@ -386,7 +390,7 @@ def test_idle_connection_is_not_closed(self): time.sleep(20) - self.assertEqual(listener.hosts_marked_down, []) + assert listener.hosts_marked_down == [] def test_host_is_not_set_to_down_after_query_oto(self): """ @@ -418,10 +422,10 @@ def test_host_is_not_set_to_down_after_query_oto(self): for f in futures: f._event.wait() - self.assertIsInstance(f._final_exception, OperationTimedOut) + assert isinstance(f._final_exception, OperationTimedOut) - self.assertEqual(listener.hosts_marked_down, []) - assert_quiescent_pool_state(self, cluster) + assert listener.hosts_marked_down == [] + assert_quiescent_pool_state(cluster) def test_can_shutdown_connection_subclass(self): start_and_prime_singledc() @@ -461,16 +465,17 @@ def test_driver_recovers_nework_isolation(self): time.sleep((idle_heartbeat_timeout + idle_heartbeat_interval) * 2) for host in cluster.metadata.all_hosts(): - self.assertIn(host, listener.hosts_marked_down) + assert host in listener.hosts_marked_down - self.assertRaises(NoHostAvailable, session.execute, "SELECT * from system.local WHERE key='local'") + with pytest.raises(NoHostAvailable): + session.execute("SELECT * from system.local WHERE key='local'") clear_queries() prime_request(AcceptConnections()) time.sleep(idle_heartbeat_timeout + idle_heartbeat_interval + 2) - self.assertIsNotNone(session.execute("SELECT * from system.local WHERE key='local'")) + assert session.execute("SELECT * from 
system.local WHERE key='local'") is not None def test_max_in_flight(self): """ Verify we don't exceed max_in_flight when borrowing connections or sending heartbeats """ diff --git a/tests/integration/simulacron/test_empty_column.py b/tests/integration/simulacron/test_empty_column.py index 046aaacf79..589f730f04 100644 --- a/tests/integration/simulacron/test_empty_column.py +++ b/tests/integration/simulacron/test_empty_column.py @@ -81,28 +81,16 @@ def test_empty_columns_with_all_row_factories(self): # Test all row factories self.cluster.profile_manager.profiles[EXEC_PROFILE_DEFAULT].row_factory = named_tuple_factory - self.assertEqual( - list(self.session.execute(query)), - [namedtuple('Row', ['field_0_', 'field_1_'])('testval', 'testval1')] - ) + assert list(self.session.execute(query)) == [namedtuple('Row', ['field_0_', 'field_1_'])('testval', 'testval1')] self.cluster.profile_manager.profiles[EXEC_PROFILE_DEFAULT].row_factory = tuple_factory - self.assertEqual( - list(self.session.execute(query)), - [('testval', 'testval1')] - ) + assert list(self.session.execute(query)) == [('testval', 'testval1')] self.cluster.profile_manager.profiles[EXEC_PROFILE_DEFAULT].row_factory = dict_factory - self.assertEqual( - list(self.session.execute(query)), - [{'': 'testval', ' ': 'testval1'}] - ) + assert list(self.session.execute(query)) == [{'': 'testval', ' ': 'testval1'}] self.cluster.profile_manager.profiles[EXEC_PROFILE_DEFAULT].row_factory = ordered_dict_factory - self.assertEqual( - list(self.session.execute(query)), - [OrderedDict((('', 'testval'), (' ', 'testval1')))] - ) + assert list(self.session.execute(query)) == [OrderedDict((('', 'testval'), (' ', 'testval1')))] def test_empty_columns_in_system_schema(self): queries = [ @@ -232,9 +220,9 @@ def test_empty_columns_in_system_schema(self): self.session = self.cluster.connect(wait_for_all_pools=True) table_metadata = self.cluster.metadata.keyspaces['testks'].tables['testtable'] - 
self.assertEqual(len(table_metadata.columns), 2) - self.assertIn('', table_metadata.columns) - self.assertIn(' ', table_metadata.columns) + assert len(table_metadata.columns) == 2 + assert '' in table_metadata.columns + assert ' ' in table_metadata.columns def test_empty_columns_with_cqlengine(self): self._prime_testtable_query() @@ -249,7 +237,4 @@ class TestModel(Model): empty = columns.Text(db_field='', primary_key=True) space = columns.Text(db_field=' ') - self.assertEqual( - [TestModel(empty='testval', space='testval1')], - list(TestModel.objects.only(['empty', 'space']).all()) - ) + assert [TestModel(empty='testval', space='testval1')] == list(TestModel.objects.only(['empty', 'space']).all()) diff --git a/tests/integration/simulacron/test_endpoint.py b/tests/integration/simulacron/test_endpoint.py index 9e2d91b6d3..5af38a9f6b 100644 --- a/tests/integration/simulacron/test_endpoint.py +++ b/tests/integration/simulacron/test_endpoint.py @@ -76,17 +76,17 @@ class EndPointTests(SimulacronCluster): def test_default_endpoint(self): hosts = self.cluster.metadata.all_hosts() - self.assertEqual(len(hosts), 3) + assert len(hosts) == 3 for host in hosts: - self.assertIsNotNone(host.endpoint) - self.assertIsInstance(host.endpoint, DefaultEndPoint) - self.assertEqual(host.address, host.endpoint.address) - self.assertEqual(host.broadcast_rpc_address, host.endpoint.address) + assert host.endpoint is not None + assert isinstance(host.endpoint, DefaultEndPoint) + assert host.address == host.endpoint.address + assert host.broadcast_rpc_address == host.endpoint.address - self.assertIsInstance(self.cluster.control_connection._connection.endpoint, DefaultEndPoint) - self.assertIsNotNone(self.cluster.control_connection._connection.endpoint) + assert isinstance(self.cluster.control_connection._connection.endpoint, DefaultEndPoint) + assert self.cluster.control_connection._connection.endpoint is not None endpoints = [host.endpoint for host in hosts] - 
self.assertIn(self.cluster.control_connection._connection.endpoint, endpoints) + assert self.cluster.control_connection._connection.endpoint in endpoints def test_custom_endpoint(self): cluster = Cluster( @@ -98,17 +98,17 @@ def test_custom_endpoint(self): cluster.connect(wait_for_all_pools=True) hosts = cluster.metadata.all_hosts() - self.assertEqual(len(hosts), 3) + assert len(hosts) == 3 for host in hosts: - self.assertIsNotNone(host.endpoint) - self.assertIsInstance(host.endpoint, AddressEndPoint) - self.assertEqual(str(host.endpoint), host.endpoint.address) - self.assertEqual(host.address, host.endpoint.address) - self.assertEqual(host.broadcast_rpc_address, host.endpoint.address) - - self.assertIsInstance(cluster.control_connection._connection.endpoint, AddressEndPoint) - self.assertIsNotNone(cluster.control_connection._connection.endpoint) + assert host.endpoint is not None + assert isinstance(host.endpoint, AddressEndPoint) + assert str(host.endpoint) == host.endpoint.address + assert host.address == host.endpoint.address + assert host.broadcast_rpc_address == host.endpoint.address + + assert isinstance(cluster.control_connection._connection.endpoint, AddressEndPoint) + assert cluster.control_connection._connection.endpoint is not None endpoints = [host.endpoint for host in hosts] - self.assertIn(cluster.control_connection._connection.endpoint, endpoints) + assert cluster.control_connection._connection.endpoint in endpoints cluster.shutdown() diff --git a/tests/integration/simulacron/test_policies.py b/tests/integration/simulacron/test_policies.py index 6d0d081889..3f94a41222 100644 --- a/tests/integration/simulacron/test_policies.py +++ b/tests/integration/simulacron/test_policies.py @@ -27,6 +27,7 @@ from itertools import count from packaging.version import Version +import pytest class BadRoundRobinPolicy(RoundRobinPolicy): @@ -101,30 +102,30 @@ def test_speculative_execution(self): # This LBP should repeat hosts up to around 30 result = 
self.session.execute(statement, execution_profile='spec_ep_brr') - self.assertEqual(7, len(result.response_future.attempted_hosts)) + assert 7 == len(result.response_future.attempted_hosts) # This LBP should keep host list to 3 result = self.session.execute(statement, execution_profile='spec_ep_rr') - self.assertEqual(3, len(result.response_future.attempted_hosts)) + assert 3 == len(result.response_future.attempted_hosts) # Spec_execution policy should limit retries to 1 result = self.session.execute(statement, execution_profile='spec_ep_rr_lim') - self.assertEqual(2, len(result.response_future.attempted_hosts)) + assert 2 == len(result.response_future.attempted_hosts) # Spec_execution policy should not be used if the query is not idempotent result = self.session.execute(statement_non_idem, execution_profile='spec_ep_brr') - self.assertEqual(1, len(result.response_future.attempted_hosts)) + assert 1 == len(result.response_future.attempted_hosts) # Default policy with non_idem query result = self.session.execute(statement_non_idem, timeout=12) - self.assertEqual(1, len(result.response_future.attempted_hosts)) + assert 1 == len(result.response_future.attempted_hosts) # Should be able to run an idempotent query against default execution policy with no speculative_execution_policy result = self.session.execute(statement, timeout=12) - self.assertEqual(1, len(result.response_future.attempted_hosts)) + assert 1 == len(result.response_future.attempted_hosts) # Test timeout with spec_ex - with self.assertRaises(OperationTimedOut): + with pytest.raises(OperationTimedOut): self.session.execute(statement, execution_profile='spec_ep_rr', timeout=.5) prepared_query_to_prime = "SELECT * FROM test3rf.test where k = ?" 
@@ -135,11 +136,11 @@ def test_speculative_execution(self): prepared_statement = self.session.prepare(prepared_query_to_prime) # non-idempotent result = self.session.execute(prepared_statement, ("0",), execution_profile='spec_ep_brr') - self.assertEqual(1, len(result.response_future.attempted_hosts)) + assert 1 == len(result.response_future.attempted_hosts) # idempotent prepared_statement.is_idempotent = True result = self.session.execute(prepared_statement, ("0",), execution_profile='spec_ep_brr') - self.assertLess(1, len(result.response_future.attempted_hosts)) + assert 1 < len(result.response_future.attempted_hosts) def test_speculative_and_timeout(self): """ @@ -162,10 +163,10 @@ def test_speculative_and_timeout(self): response_future = self.session.execute_async(statement, execution_profile='spec_ep_brr_lim', timeout=14) response_future._event.wait(16) - self.assertIsInstance(response_future._final_exception, OperationTimedOut) + assert isinstance(response_future._final_exception, OperationTimedOut) # This is because 14 / 4 + 1 = 4 - self.assertEqual(len(response_future.attempted_hosts), 4) + assert len(response_future.attempted_hosts) == 4 def test_delay_can_be_0(self): """ @@ -199,11 +200,11 @@ def patched(*args, **kwargs): stmt = SimpleStatement(query_to_prime) stmt.is_idempotent = True results = session.execute(stmt, execution_profile="spec") - self.assertEqual(len(results.response_future.attempted_hosts), 3) + assert len(results.response_future.attempted_hosts) == 3 # send_request is called number_of_requests times for the speculative request # plus one for the call from the main thread. 
- self.assertEqual(next(counter), number_of_requests + 1) + assert next(counter) == number_of_requests + 1 class CustomRetryPolicy(RetryPolicy): @@ -306,7 +307,7 @@ def test_retry_policy_ignores_and_rethrows(self): then["write_type"] = "CDC" prime_query(query_to_prime_cdc, rows=None, column_types=None, then=then) - with self.assertRaises(WriteTimeout): + with pytest.raises(WriteTimeout): self.session.execute(query_to_prime_simple) #CDC should be ignored @@ -337,7 +338,7 @@ def test_retry_policy_with_prepared(self): } prime_query(query_to_prime, then=then, rows=None, column_types=None) self.session.execute(query_to_prime) - self.assertEqual(next(counter_policy.write_timeout), 1) + assert next(counter_policy.write_timeout) == 1 counter_policy.reset_counters() query_to_prime_prepared = "SELECT * from simulacron_keyspace.simulacron_table WHERE key = :key" @@ -349,11 +350,11 @@ def test_retry_policy_with_prepared(self): bound_stm = prepared_stmt.bind({"key": "0"}) self.session.execute(bound_stm) - self.assertEqual(next(counter_policy.write_timeout), 1) + assert next(counter_policy.write_timeout) == 1 counter_policy.reset_counters() self.session.execute(prepared_stmt, ("0",)) - self.assertEqual(next(counter_policy.write_timeout), 1) + assert next(counter_policy.write_timeout) == 1 def test_setting_retry_policy_to_statement(self): """ @@ -385,13 +386,13 @@ def test_setting_retry_policy_to_statement(self): prepared_stmt = self.session.prepare(query_to_prime_prepared) prepared_stmt.retry_policy = counter_policy self.session.execute(prepared_stmt, ("0",)) - self.assertEqual(next(counter_policy.write_timeout), 1) + assert next(counter_policy.write_timeout) == 1 counter_policy.reset_counters() bound_stmt = prepared_stmt.bind({"key": "0"}) bound_stmt.retry_policy = counter_policy self.session.execute(bound_stmt) - self.assertEqual(next(counter_policy.write_timeout), 1) + assert next(counter_policy.write_timeout) == 1 def test_retry_policy_on_request_error(self): """ @@ -438,12 
+439,12 @@ def test_retry_policy_on_request_error(self): prime_query(query_to_prime, then=prime_error, rows=None, column_types=None) rf = self.session.execute_async(query_to_prime) - with self.assertRaises(exc): + with pytest.raises(exc): rf.result() - self.assertEqual(len(rf.attempted_hosts), 1) # no retry + assert len(rf.attempted_hosts) == 1 # no retry - self.assertEqual(next(retry_policy.request_error), 4) + assert next(retry_policy.request_error) == 4 # Test that by default, retry on next host retry_policy = RetryPolicy() @@ -455,7 +456,7 @@ def test_retry_policy_on_request_error(self): prime_query(query_to_prime, then=e, rows=None, column_types=None) rf = self.session.execute_async(query_to_prime) - with self.assertRaises(NoHostAvailable): + with pytest.raises(NoHostAvailable): rf.result() - self.assertEqual(len(rf.attempted_hosts), 3) # all 3 nodes failed + assert len(rf.attempted_hosts) == 3 # all 3 nodes failed diff --git a/tests/integration/standard/column_encryption/test_policies.py b/tests/integration/standard/column_encryption/test_policies.py index 36376d689b..9a1d186895 100644 --- a/tests/integration/standard/column_encryption/test_policies.py +++ b/tests/integration/standard/column_encryption/test_policies.py @@ -59,14 +59,14 @@ def test_end_to_end_prepared(self): # A straight select from the database will now return the decrypted bits. We select both encrypted and unencrypted # values here to confirm that we don't interfere with regular processing of unencrypted vals. (encrypted,unencrypted) = session.execute("select encrypted, unencrypted from foo.bar where unencrypted = %s allow filtering", (expected,)).one() - self.assertEqual(expected, encrypted) - self.assertEqual(expected, unencrypted) + assert expected == encrypted + assert expected == unencrypted # Confirm the same behaviour from a subsequent prepared statement as well prepared = session.prepare("select encrypted, unencrypted from foo.bar where unencrypted = ? 
allow filtering") (encrypted,unencrypted) = session.execute(prepared, [expected]).one() - self.assertEqual(expected, encrypted) - self.assertEqual(expected, unencrypted) + assert expected == encrypted + assert expected == unencrypted def test_end_to_end_simple(self): @@ -80,21 +80,21 @@ def test_end_to_end_simple(self): # Use encode_and_encrypt helper function to populate date for i in range(1,100): - self.assertIsNotNone(i) + assert i is not None encrypted = cl_policy.encode_and_encrypt(col_desc, i) session.execute("insert into foo.bar (encrypted, unencrypted) values (%s,%s)", (encrypted, i)) # A straight select from the database will now return the decrypted bits. We select both encrypted and unencrypted # values here to confirm that we don't interfere with regular processing of unencrypted vals. (encrypted,unencrypted) = session.execute("select encrypted, unencrypted from foo.bar where unencrypted = %s allow filtering", (expected,)).one() - self.assertEqual(expected, encrypted) - self.assertEqual(expected, unencrypted) + assert expected == encrypted + assert expected == unencrypted # Confirm the same behaviour from a subsequent prepared statement as well prepared = session.prepare("select encrypted, unencrypted from foo.bar where unencrypted = ? 
allow filtering") (encrypted,unencrypted) = session.execute(prepared, [expected]).one() - self.assertEqual(expected, encrypted) - self.assertEqual(expected, unencrypted) + assert expected == encrypted + assert expected == unencrypted def test_end_to_end_different_cle_contexts_different_ivs(self): """ @@ -119,7 +119,7 @@ def test_end_to_end_different_cle_contexts_different_ivs(self): # Use encode_and_encrypt helper function to populate date for i in range(1,100): - self.assertIsNotNone(i) + assert i is not None encrypted = cl_policy1.encode_and_encrypt(col_desc1, i) session1.execute("insert into foo.bar (encrypted, unencrypted) values (%s,%s)", (encrypted, i)) session1.shutdown() @@ -129,15 +129,15 @@ def test_end_to_end_different_cle_contexts_different_ivs(self): # that would entail not re-using any cached ciphers AES256ColumnEncryptionPolicy._build_cipher.cache_clear() cache_info = cl_policy1.cache_info() - self.assertEqual(cache_info.currsize, 0) + assert cache_info.currsize == 0 iv2 = os.urandom(AES256_BLOCK_SIZE_BYTES) (_, cl_policy2) = self._create_policy(key, iv=iv2) cluster2 = TestCluster(column_encryption_policy=cl_policy2) session2 = cluster2.connect() (encrypted,unencrypted) = session2.execute("select encrypted, unencrypted from foo.bar where unencrypted = %s allow filtering", (expected,)).one() - self.assertEqual(expected, encrypted) - self.assertEqual(expected, unencrypted) + assert expected == encrypted + assert expected == unencrypted def test_end_to_end_different_cle_contexts_different_policies(self): """ @@ -162,10 +162,10 @@ def test_end_to_end_different_cle_contexts_different_policies(self): # A straight select from the database will now return the decrypted bits. We select both encrypted and unencrypted # values here to confirm that we don't interfere with regular processing of unencrypted vals. 
(encrypted,unencrypted) = session2.execute("select encrypted, unencrypted from foo.bar where unencrypted = %s allow filtering", (expected,)).one() - self.assertEqual(cl_policy.encode_and_encrypt(col_desc, expected), encrypted) - self.assertEqual(expected, unencrypted) + assert cl_policy.encode_and_encrypt(col_desc, expected) == encrypted + assert expected == unencrypted # Confirm the same behaviour from a subsequent prepared statement as well prepared = session2.prepare("select encrypted, unencrypted from foo.bar where unencrypted = ? allow filtering") (encrypted,unencrypted) = session2.execute(prepared, [expected]).one() - self.assertEqual(cl_policy.encode_and_encrypt(col_desc, expected), encrypted) + assert cl_policy.encode_and_encrypt(col_desc, expected) == encrypted diff --git a/tests/integration/standard/test_authentication.py b/tests/integration/standard/test_authentication.py index 122df55a02..eb8019bf65 100644 --- a/tests/integration/standard/test_authentication.py +++ b/tests/integration/standard/test_authentication.py @@ -24,6 +24,7 @@ from tests.integration.util import assert_quiescent_pool_state import unittest +import pytest log = logging.getLogger(__name__) @@ -105,56 +106,52 @@ def test_auth_connect(self): cluster = self.cluster_as(user, passwd) session = cluster.connect(wait_for_all_pools=True) try: - self.assertTrue(session.execute("SELECT release_version FROM system.local WHERE key='local'")) - assert_quiescent_pool_state(self, cluster, wait=1) + assert session.execute("SELECT release_version FROM system.local WHERE key='local'") + assert_quiescent_pool_state(cluster, wait=1) for pool in session.get_pools(): connection, _ = pool.borrow_connection(timeout=0) - self.assertEqual(connection.authenticator.server_authenticator_class, 'org.apache.cassandra.auth.PasswordAuthenticator') + assert connection.authenticator.server_authenticator_class == 'org.apache.cassandra.auth.PasswordAuthenticator' pool.return_connection(connection) finally: 
cluster.shutdown() finally: root_session.execute('DROP USER %s', user) - assert_quiescent_pool_state(self, root_session.cluster, wait=1) + assert_quiescent_pool_state(root_session.cluster, wait=1) root_session.cluster.shutdown() def test_connect_wrong_pwd(self): cluster = self.cluster_as('cassandra', 'wrong_pass') try: - self.assertRaisesRegex(NoHostAvailable, - '.*AuthenticationFailed.', - cluster.connect) - assert_quiescent_pool_state(self, cluster) + with pytest.raises(NoHostAvailable, match='.*AuthenticationFailed.'): + cluster.connect() + assert_quiescent_pool_state(cluster) finally: cluster.shutdown() def test_connect_wrong_username(self): cluster = self.cluster_as('wrong_user', 'cassandra') try: - self.assertRaisesRegex(NoHostAvailable, - '.*AuthenticationFailed.*', - cluster.connect) - assert_quiescent_pool_state(self, cluster) + with pytest.raises(NoHostAvailable, match='.*AuthenticationFailed.*'): + cluster.connect() + assert_quiescent_pool_state(cluster) finally: cluster.shutdown() def test_connect_empty_pwd(self): cluster = self.cluster_as('Cassandra', '') try: - self.assertRaisesRegex(NoHostAvailable, - '.*AuthenticationFailed.*', - cluster.connect) - assert_quiescent_pool_state(self, cluster) + with pytest.raises(NoHostAvailable, match='.*AuthenticationFailed.*'): + cluster.connect() + assert_quiescent_pool_state(cluster) finally: cluster.shutdown() def test_connect_no_auth_provider(self): cluster = TestCluster() try: - self.assertRaisesRegex(NoHostAvailable, - '.*AuthenticationFailed.*', - cluster.connect) - assert_quiescent_pool_state(self, cluster) + with pytest.raises(NoHostAvailable, match='.*AuthenticationFailed.*'): + cluster.connect() + assert_quiescent_pool_state(cluster) finally: cluster.shutdown() @@ -184,8 +181,9 @@ def test_host_passthrough(self): provider = SaslAuthProvider(**sasl_kwargs) host = 'thehostname' authenticator = provider.new_authenticator(host) - self.assertEqual(authenticator.sasl.host, host) + assert 
authenticator.sasl.host == host def test_host_rejected(self): sasl_kwargs = {'host': 'something'} - self.assertRaises(ValueError, SaslAuthProvider, **sasl_kwargs) + with pytest.raises(ValueError): + SaslAuthProvider(**sasl_kwargs) diff --git a/tests/integration/standard/test_authentication_misconfiguration.py b/tests/integration/standard/test_authentication_misconfiguration.py index 2b02664c3f..9ad4ad997d 100644 --- a/tests/integration/standard/test_authentication_misconfiguration.py +++ b/tests/integration/standard/test_authentication_misconfiguration.py @@ -40,7 +40,7 @@ def test_connect_no_auth_provider(self): cluster.connect() cluster.refresh_nodes() down_hosts = [host for host in cluster.metadata.all_hosts() if not host.is_up] - self.assertEqual(len(down_hosts), 1) + assert len(down_hosts) == 1 cluster.shutdown() @classmethod diff --git a/tests/integration/standard/test_client_warnings.py b/tests/integration/standard/test_client_warnings.py index ce5332a59f..781b5b7860 100644 --- a/tests/integration/standard/test_client_warnings.py +++ b/tests/integration/standard/test_client_warnings.py @@ -19,6 +19,7 @@ from tests.integration import (use_singledc, PROTOCOL_VERSION, local, TestCluster, requires_custom_payload, xfail_scylla) +from tests.util import assertRegex, assertDictEqual def setup_module(): @@ -70,8 +71,8 @@ def test_warning_basic(self): """ future = self.session.execute_async(self.warn_batch) future.result() - self.assertEqual(len(future.warnings), 1) - self.assertRegex(future.warnings[0], 'Batch.*exceeding.*') + assert len(future.warnings) == 1 + assertRegex(future.warnings[0], 'Batch.*exceeding.*') def test_warning_with_trace(self): """ @@ -86,9 +87,9 @@ def test_warning_with_trace(self): """ future = self.session.execute_async(self.warn_batch, trace=True) future.result() - self.assertEqual(len(future.warnings), 1) - self.assertRegex(future.warnings[0], 'Batch.*exceeding.*') - self.assertIsNotNone(future.get_query_trace()) + assert 
len(future.warnings) == 1 + assertRegex(future.warnings[0], 'Batch.*exceeding.*') + assert future.get_query_trace() is not None @local @requires_custom_payload @@ -106,9 +107,9 @@ def test_warning_with_custom_payload(self): payload = {'key': b'value'} future = self.session.execute_async(self.warn_batch, custom_payload=payload) future.result() - self.assertEqual(len(future.warnings), 1) - self.assertRegex(future.warnings[0], 'Batch.*exceeding.*') - self.assertDictEqual(future.custom_payload, payload) + assert len(future.warnings) == 1 + assertRegex(future.warnings[0], 'Batch.*exceeding.*') + assertDictEqual(future.custom_payload, payload) @local @requires_custom_payload @@ -126,7 +127,7 @@ def test_warning_with_trace_and_custom_payload(self): payload = {'key': b'value'} future = self.session.execute_async(self.warn_batch, trace=True, custom_payload=payload) future.result() - self.assertEqual(len(future.warnings), 1) - self.assertRegex(future.warnings[0], 'Batch.*exceeding.*') - self.assertIsNotNone(future.get_query_trace()) - self.assertDictEqual(future.custom_payload, payload) + assert len(future.warnings) == 1 + assertRegex(future.warnings[0], 'Batch.*exceeding.*') + assert future.get_query_trace() is not None + assertDictEqual(future.custom_payload, payload) diff --git a/tests/integration/standard/test_cluster.py b/tests/integration/standard/test_cluster.py index 503b9304b3..c83f454a0f 100644 --- a/tests/integration/standard/test_cluster.py +++ b/tests/integration/standard/test_cluster.py @@ -24,6 +24,7 @@ import warnings from packaging.version import Version import os +import pytest import cassandra from cassandra.cluster import NoHostAvailable, ExecutionProfile, EXEC_PROFILE_DEFAULT, ControlConnection, Cluster @@ -44,6 +45,7 @@ get_unsupported_upper_protocol, protocolv6, local, CASSANDRA_IP, greaterthanorequalcass30, \ lessthanorequalcass40, TestCluster, PROTOCOL_VERSION, xfail_scylla, incorrect_test from tests.integration.util import 
assert_quiescent_pool_state +from tests.util import assertListEqual import sys log = logging.getLogger(__name__) @@ -87,9 +89,9 @@ def test_ignored_host_up(self): cluster.connect() for host in cluster.metadata.all_hosts(): if str(host) == "127.0.0.1:9042": - self.assertTrue(host.is_up) + assert host.is_up else: - self.assertIsNone(host.is_up) + assert host.is_up is None cluster.shutdown() @local @@ -104,7 +106,7 @@ def test_host_resolution(self): @test_category connection """ cluster = TestCluster(contact_points=["localhost"], connect_timeout=1) - self.assertTrue(DefaultEndPoint('127.0.0.1') in cluster.endpoints_resolved) + assert DefaultEndPoint('127.0.0.1') in cluster.endpoints_resolved @local def test_host_duplication(self): @@ -122,11 +124,11 @@ def test_host_duplication(self): connect_timeout=1 ) cluster.connect(wait_for_all_pools=True) - self.assertEqual(len(cluster.metadata.all_hosts()), 3) + assert len(cluster.metadata.all_hosts()) == 3 cluster.shutdown() cluster = TestCluster(contact_points=["127.0.0.1", "localhost"], connect_timeout=1) cluster.connect(wait_for_all_pools=True) - self.assertEqual(len(cluster.metadata.all_hosts()), 3) + assert len(cluster.metadata.all_hosts()) == 3 cluster.shutdown() @local @@ -150,7 +152,7 @@ def test_raise_error_on_control_connection_timeout(self): get_node(1).pause() cluster = TestCluster(contact_points=['127.0.0.1'], connect_timeout=1) - with self.assertRaisesRegex(NoHostAvailable, r"OperationTimedOut\('errors=Timed out creating connection \(1 seconds\)"): + with pytest.raises(NoHostAvailable, match=r"OperationTimedOut\('errors=Timed out creating connection \(1 seconds\)"): cluster.connect() cluster.shutdown() @@ -168,7 +170,7 @@ def test_basic(self): CREATE KEYSPACE clustertests WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'} """) - self.assertFalse(result) + assert not result result = execute_with_long_wait_retry(session, """ @@ -179,16 +181,16 @@ def test_basic(self): PRIMARY KEY (a, b) ) 
""") - self.assertFalse(result) + assert not result result = session.execute( """ INSERT INTO clustertests.cf0 (a, b, c) VALUES ('a', 'b', 'c') """) - self.assertFalse(result) + assert not result result = session.execute("SELECT * FROM clustertests.cf0") - self.assertEqual([('a', 'b', 'c')], result) + assert [('a', 'b', 'c')] == result execute_with_long_wait_retry(session, "DROP KEYSPACE clustertests") @@ -218,13 +220,13 @@ def cleanup(): # Test with empty list self.cluster_to_shutdown = TestCluster(contact_points=[]) - with self.assertRaises(NoHostAvailable): + with pytest.raises(NoHostAvailable): self.cluster_to_shutdown.connect() self.cluster_to_shutdown.shutdown() # Test with only invalid self.cluster_to_shutdown = TestCluster(contact_points=('1.2.3.4',)) - with self.assertRaises(NoHostAvailable): + with pytest.raises(NoHostAvailable): self.cluster_to_shutdown.connect() self.cluster_to_shutdown.shutdown() @@ -250,35 +252,35 @@ def test_protocol_negotiation(self): """ cluster = Cluster() - self.assertLessEqual(cluster.protocol_version, cassandra.ProtocolVersion.MAX_SUPPORTED) + assert cluster.protocol_version <= cassandra.ProtocolVersion.MAX_SUPPORTED session = cluster.connect() updated_protocol_version = session._protocol_version updated_cluster_version = cluster.protocol_version # Make sure the correct protocol was selected by default if CASSANDRA_VERSION >= Version('4.0-beta5'): - self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.V5) - self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.V5) + assert updated_protocol_version == cassandra.ProtocolVersion.V5 + assert updated_cluster_version == cassandra.ProtocolVersion.V5 elif CASSANDRA_VERSION >= Version('4.0-a'): - self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.V4) - self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.V4) + assert updated_protocol_version == cassandra.ProtocolVersion.V4 + assert updated_cluster_version == 
cassandra.ProtocolVersion.V4 elif CASSANDRA_VERSION >= Version('3.11'): - self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.V4) - self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.V4) + assert updated_protocol_version == cassandra.ProtocolVersion.V4 + assert updated_cluster_version == cassandra.ProtocolVersion.V4 elif CASSANDRA_VERSION >= Version('3.0'): - self.assertEqual(updated_protocol_version, cassandra.ProtocolVersion.V4) - self.assertEqual(updated_cluster_version, cassandra.ProtocolVersion.V4) + assert updated_protocol_version == cassandra.ProtocolVersion.V4 + assert updated_cluster_version == cassandra.ProtocolVersion.V4 elif CASSANDRA_VERSION >= Version('2.2'): - self.assertEqual(updated_protocol_version, 4) - self.assertEqual(updated_cluster_version, 4) + assert updated_protocol_version == 4 + assert updated_cluster_version == 4 elif CASSANDRA_VERSION >= Version('2.1'): - self.assertEqual(updated_protocol_version, 3) - self.assertEqual(updated_cluster_version, 3) + assert updated_protocol_version == 3 + assert updated_cluster_version == 3 elif CASSANDRA_VERSION >= Version('2.0'): - self.assertEqual(updated_protocol_version, 2) - self.assertEqual(updated_cluster_version, 2) + assert updated_protocol_version == 2 + assert updated_cluster_version == 2 else: - self.assertEqual(updated_protocol_version, 1) - self.assertEqual(updated_cluster_version, 1) + assert updated_protocol_version == 1 + assert updated_cluster_version == 1 cluster.shutdown() @@ -308,7 +310,7 @@ def test_invalid_protocol_negotation(self): log.debug('got upper_bound of {}'.format(upper_bound)) if upper_bound is not None: cluster = TestCluster(protocol_version=upper_bound) - with self.assertRaises(NoHostAvailable): + with pytest.raises(NoHostAvailable): cluster.connect() cluster.shutdown() @@ -316,7 +318,7 @@ def test_invalid_protocol_negotation(self): log.debug('got lower_bound of {}'.format(lower_bound)) if lower_bound is not None: cluster = 
TestCluster(protocol_version=lower_bound) - with self.assertRaises(NoHostAvailable): + with pytest.raises(NoHostAvailable): cluster.connect() cluster.shutdown() @@ -331,15 +333,15 @@ def test_connect_on_keyspace(self): """ INSERT INTO test1rf.test (k, v) VALUES (8889, 8889) """) - self.assertFalse(result) + assert not result result = session.execute("SELECT * FROM test1rf.test") - self.assertEqual([(8889, 8889)], result, "Rows in ResultSet are {0}".format(result.current_rows)) + assert [(8889, 8889)] == result, "Rows in ResultSet are {0}".format(result.current_rows) # test_connect_on_keyspace session2 = cluster.connect('test1rf') result2 = session2.execute("SELECT * FROM test") - self.assertEqual(result, result2) + assert result == result2 cluster.shutdown() def test_set_keyspace_twice(self): @@ -366,31 +368,37 @@ def test_connect_to_already_shutdown_cluster(self): """ cluster = TestCluster() cluster.shutdown() - self.assertRaises(Exception, cluster.connect) + with pytest.raises(Exception): + cluster.connect() def test_auth_provider_is_callable(self): """ Ensure that auth_providers are always callable """ - self.assertRaises(TypeError, Cluster, auth_provider=1, protocol_version=1) + with pytest.raises(TypeError): + Cluster(auth_provider=1, protocol_version=1) c = TestCluster(protocol_version=1) - self.assertRaises(TypeError, setattr, c, 'auth_provider', 1) + with pytest.raises(TypeError): + setattr(c, 'auth_provider', 1) def test_v2_auth_provider(self): """ Check for v2 auth_provider compliance """ bad_auth_provider = lambda x: {'username': 'foo', 'password': 'bar'} - self.assertRaises(TypeError, Cluster, auth_provider=bad_auth_provider, protocol_version=2) + with pytest.raises(TypeError): + Cluster(auth_provider=bad_auth_provider, protocol_version=2) c = TestCluster(protocol_version=2) - self.assertRaises(TypeError, setattr, c, 'auth_provider', bad_auth_provider) + with pytest.raises(TypeError): + setattr(c, 'auth_provider', bad_auth_provider) def 
test_conviction_policy_factory_is_callable(self): """ Ensure that conviction_policy_factory are always callable """ - self.assertRaises(ValueError, Cluster, conviction_policy_factory=1) + with pytest.raises(ValueError): + Cluster(conviction_policy_factory=1) def test_connect_to_bad_hosts(self): """ @@ -400,7 +408,8 @@ def test_connect_to_bad_hosts(self): cluster = TestCluster(contact_points=['127.1.2.9', '127.1.2.10'], protocol_version=PROTOCOL_VERSION) - self.assertRaises(NoHostAvailable, cluster.connect) + with pytest.raises(NoHostAvailable): + cluster.connect() def test_refresh_schema(self): cluster = TestCluster() @@ -409,8 +418,8 @@ def test_refresh_schema(self): original_meta = cluster.metadata.keyspaces # full schema refresh, with wait cluster.refresh_schema_metadata() - self.assertIsNot(original_meta, cluster.metadata.keyspaces) - self.assertEqual(original_meta, cluster.metadata.keyspaces) + assert original_meta is not cluster.metadata.keyspaces + assert original_meta == cluster.metadata.keyspaces cluster.shutdown() @@ -424,10 +433,10 @@ def test_refresh_schema_keyspace(self): # only refresh one keyspace cluster.refresh_keyspace_metadata('system') current_meta = cluster.metadata.keyspaces - self.assertIs(original_meta, current_meta) + assert original_meta is current_meta current_system_meta = current_meta['system'] - self.assertIsNot(original_system_meta, current_system_meta) - self.assertEqual(original_system_meta.as_cql_query(), current_system_meta.as_cql_query()) + assert original_system_meta is not current_system_meta + assert original_system_meta.as_cql_query() == current_system_meta.as_cql_query() cluster.shutdown() def test_refresh_schema_table(self): @@ -443,10 +452,10 @@ def test_refresh_schema_table(self): current_meta = cluster.metadata.keyspaces current_system_meta = current_meta['system'] current_system_schema_meta = current_system_meta.tables['local'] - self.assertIs(original_meta, current_meta) - self.assertIs(original_system_meta, 
current_system_meta) - self.assertIsNot(original_system_schema_meta, current_system_schema_meta) - self.assertEqual(original_system_schema_meta.as_cql_query(), current_system_schema_meta.as_cql_query()) + assert original_meta is current_meta + assert original_system_meta is current_system_meta + assert original_system_schema_meta is not current_system_schema_meta + assert original_system_schema_meta.as_cql_query() == current_system_schema_meta.as_cql_query() cluster.shutdown() def test_refresh_schema_type(self): @@ -473,10 +482,10 @@ def test_refresh_schema_type(self): current_meta = cluster.metadata.keyspaces current_test1rf_meta = current_meta[keyspace_name] current_type_meta = current_test1rf_meta.user_types[type_name] - self.assertIs(original_meta, current_meta) - self.assertEqual(original_test1rf_meta.export_as_string(), current_test1rf_meta.export_as_string()) - self.assertIsNot(original_type_meta, current_type_meta) - self.assertEqual(original_type_meta.as_cql_query(), current_type_meta.as_cql_query()) + assert original_meta is current_meta + assert original_test1rf_meta.export_as_string() == current_test1rf_meta.export_as_string() + assert original_type_meta is not current_type_meta + assert original_type_meta.as_cql_query() == current_type_meta.as_cql_query() cluster.shutdown() @local @@ -498,24 +507,25 @@ def patched_wait_for_responses(*args, **kwargs): # cluster agreement wait exceeded c = TestCluster(max_schema_agreement_wait=agreement_timeout) c.connect() - self.assertTrue(c.metadata.keyspaces) + assert c.metadata.keyspaces # cluster agreement wait used for refresh original_meta = c.metadata.keyspaces start_time = time.time() - self.assertRaisesRegex(Exception, r"Schema metadata was not refreshed.*", c.refresh_schema_metadata) + with pytest.raises(Exception, match=r"Schema metadata was not refreshed.*"): + c.refresh_schema_metadata() end_time = time.time() - self.assertGreaterEqual(end_time - start_time, agreement_timeout) - 
self.assertIs(original_meta, c.metadata.keyspaces) + assert end_time - start_time >= agreement_timeout + assert original_meta is c.metadata.keyspaces # refresh wait overrides cluster value original_meta = c.metadata.keyspaces start_time = time.time() c.refresh_schema_metadata(max_schema_agreement_wait=0) end_time = time.time() - self.assertLess(end_time - start_time, agreement_timeout) - self.assertIsNot(original_meta, c.metadata.keyspaces) - self.assertEqual(original_meta, c.metadata.keyspaces) + assert end_time - start_time < agreement_timeout + assert original_meta is not c.metadata.keyspaces + assert original_meta == c.metadata.keyspaces c.shutdown() @@ -525,26 +535,26 @@ def patched_wait_for_responses(*args, **kwargs): start_time = time.time() s = c.connect() end_time = time.time() - self.assertLess(end_time - start_time, refresh_threshold) - self.assertTrue(c.metadata.keyspaces) + assert end_time - start_time < refresh_threshold + assert c.metadata.keyspaces # cluster agreement wait used for refresh original_meta = c.metadata.keyspaces start_time = time.time() c.refresh_schema_metadata() end_time = time.time() - self.assertLess(end_time - start_time, refresh_threshold) - self.assertIsNot(original_meta, c.metadata.keyspaces) - self.assertEqual(original_meta, c.metadata.keyspaces) + assert end_time - start_time < refresh_threshold + assert original_meta is not c.metadata.keyspaces + assert original_meta == c.metadata.keyspaces # refresh wait overrides cluster value original_meta = c.metadata.keyspaces start_time = time.time() - self.assertRaisesRegex(Exception, r"Schema metadata was not refreshed.*", c.refresh_schema_metadata, - max_schema_agreement_wait=agreement_timeout) + with pytest.raises(Exception, match=r"Schema metadata was not refreshed.*"): + c.refresh_schema_metadata(max_schema_agreement_wait=agreement_timeout) end_time = time.time() - self.assertGreaterEqual(end_time - start_time, agreement_timeout) - self.assertIs(original_meta, 
c.metadata.keyspaces) + assert end_time - start_time >= agreement_timeout + assert original_meta is c.metadata.keyspaces c.shutdown() def test_trace(self): @@ -566,7 +576,7 @@ def test_trace(self): query = "SELECT * FROM system.local WHERE key='local'" statement = SimpleStatement(query) result = session.execute(statement) - self.assertIsNone(result.get_query_trace()) + assert result.get_query_trace() is None statement2 = SimpleStatement(query) future = session.execute_async(statement2, trace=True) @@ -576,7 +586,7 @@ def test_trace(self): statement2 = SimpleStatement(query) future = session.execute_async(statement2) future.result() - self.assertIsNone(future.get_query_trace()) + assert future.get_query_trace() is None prepared = session.prepare("SELECT * FROM system.local WHERE key='local'") future = session.execute_async(prepared, parameters=(), trace=True) @@ -642,7 +652,7 @@ def test_one_returns_none(self): """ with TestCluster() as cluster: session = cluster.connect() - self.assertIsNone(session.execute("SELECT * from system.local WHERE key='madeup_key'").one()) + assert session.execute("SELECT * from system.local WHERE key='madeup_key'").one() is None def test_string_coverage(self): """ @@ -656,11 +666,11 @@ def test_string_coverage(self): statement = SimpleStatement(query) future = session.execute_async(statement) - self.assertIn(query, str(future)) + assert query in str(future) future.result() - self.assertIn(query, str(future)) - self.assertIn('result', str(future)) + assert query in str(future) + assert 'result' in str(future) cluster.shutdown() def test_can_connect_with_plainauth(self): @@ -708,15 +718,12 @@ def _warning_are_issued_when_auth(self, auth_provider): with MockLoggingHandler().set_module_name(connection.__name__) as mock_handler: with TestCluster(auth_provider=auth_provider) as cluster: session = cluster.connect() - self.assertIsNotNone(session.execute("SELECT * from system.local WHERE key='local'")) + assert session.execute("SELECT * from 
system.local WHERE key='local'") is not None # Three conenctions to nodes plus the control connection auth_warning = mock_handler.get_message_count('warning', "An authentication challenge was not sent") - self.assertGreaterEqual(auth_warning, 4) - self.assertEqual( - auth_warning, - mock_handler.get_message_count("debug", "Got ReadyMessage on new connection") - ) + assert auth_warning >= 4 + assert auth_warning == mock_handler.get_message_count("debug", "Got ReadyMessage on new connection") def test_idle_heartbeat(self): interval = 2 @@ -732,7 +739,7 @@ def test_idle_heartbeat(self): for h in cluster.get_connection_holders(): for c in h.get_connections(): # make sure none are idle (should have startup messages - self.assertFalse(c.is_idle) + assert not c.is_idle with c.lock: connection_request_ids[id(c)] = deque(c.request_ids) # copy of request ids @@ -746,45 +753,45 @@ def test_idle_heartbeat(self): expected_ids = connection_request_ids[id(c)] expected_ids.rotate(-1) with c.lock: - self.assertListEqual(list(c.request_ids), list(expected_ids)) + assertListEqual(list(c.request_ids), list(expected_ids)) # assert idle status - self.assertTrue(all(c.is_idle for c in connections)) + assert all(c.is_idle for c in connections) # send messages on all connections statements_and_params = [("SELECT release_version FROM system.local WHERE key='local'", ())] * len(cluster.metadata.all_hosts()) results = execute_concurrent(session, statements_and_params) for success, result in results: - self.assertTrue(success) + assert success # assert not idle status - self.assertFalse(any(c.is_idle if not c.is_control_connection else False for c in connections)) + assert not any(c.is_idle if not c.is_control_connection else False for c in connections) # holders include session pools and cc holders = cluster.get_connection_holders() - self.assertIn(cluster.control_connection, holders) - self.assertEqual(len(holders), len(cluster.metadata.all_hosts()) + 1) # hosts pools, 1 for cc + assert 
cluster.control_connection in holders + assert len(holders) == len(cluster.metadata.all_hosts()) + 1 # hosts pools, 1 for cc # include additional sessions session2 = cluster.connect(wait_for_all_pools=True) holders = cluster.get_connection_holders() - self.assertIn(cluster.control_connection, holders) - self.assertEqual(len(holders), 2 * len(cluster.metadata.all_hosts()) + 1) # 2 sessions' hosts pools, 1 for cc + assert cluster.control_connection in holders + assert len(holders) == 2 * len(cluster.metadata.all_hosts()) + 1 # 2 sessions' hosts pools, 1 for cc cluster._idle_heartbeat.stop() cluster._idle_heartbeat.join() - assert_quiescent_pool_state(self, cluster) + assert_quiescent_pool_state(cluster) cluster.shutdown() @patch('cassandra.cluster.Cluster.idle_heartbeat_interval', new=0.1) def test_idle_heartbeat_disabled(self): - self.assertTrue(Cluster.idle_heartbeat_interval) + assert Cluster.idle_heartbeat_interval # heartbeat disabled with '0' cluster = TestCluster(idle_heartbeat_interval=0) - self.assertEqual(cluster.idle_heartbeat_interval, 0) + assert cluster.idle_heartbeat_interval == 0 session = cluster.connect() # let two heatbeat intervals pass (first one had startup messages in it) @@ -793,7 +800,7 @@ def test_idle_heartbeat_disabled(self): connections = [c for holders in cluster.get_connection_holders() for c in holders.get_connections()] # assert not idle status (should never get reset because there is not heartbeat) - self.assertFalse(any(c.is_idle for c in connections)) + assert not any(c.is_idle for c in connections) cluster.shutdown() @@ -805,10 +812,10 @@ def test_pool_management(self): # prepare p = session.prepare("SELECT * FROM system.local WHERE key=?") - self.assertTrue(session.execute(p, ('local',))) + assert session.execute(p, ('local',)) # simple - self.assertTrue(session.execute("SELECT * FROM system.local WHERE key='local'")) + assert session.execute("SELECT * FROM system.local WHERE key='local'") # set keyspace 
session.set_keyspace('system') @@ -822,7 +829,7 @@ def test_pool_management(self): cluster.refresh_schema_metadata() cluster.refresh_schema_metadata(max_schema_agreement_wait=0) - assert_quiescent_pool_state(self, cluster) + assert_quiescent_pool_state(cluster) cluster.shutdown() @@ -852,7 +859,7 @@ def test_profile_load_balancing(self): for _ in expected_hosts: rs = session.execute(query) queried_hosts.add(rs.response_future._current_host) - self.assertEqual(queried_hosts, expected_hosts) + assert queried_hosts == expected_hosts # by name we should only hit the one expected_hosts = set(h for h in cluster.metadata.all_hosts() if h.address == CASSANDRA_IP) @@ -860,13 +867,13 @@ def test_profile_load_balancing(self): for _ in cluster.metadata.all_hosts(): rs = session.execute(query, execution_profile='node1') queried_hosts.add(rs.response_future._current_host) - self.assertEqual(queried_hosts, expected_hosts) + assert queried_hosts == expected_hosts # use a copied instance and override the row factory # assert last returned value can be accessed as a namedtuple so we can prove something different named_tuple_row = rs.one() - self.assertIsInstance(named_tuple_row, tuple) - self.assertTrue(named_tuple_row.release_version) + assert isinstance(named_tuple_row, tuple) + assert named_tuple_row.release_version tmp_profile = copy(node1) tmp_profile.row_factory = tuple_factory @@ -874,26 +881,22 @@ def test_profile_load_balancing(self): for _ in cluster.metadata.all_hosts(): rs = session.execute(query, execution_profile=tmp_profile) queried_hosts.add(rs.response_future._current_host) - self.assertEqual(queried_hosts, expected_hosts) + assert queried_hosts == expected_hosts tuple_row = rs.one() - self.assertIsInstance(tuple_row, tuple) - with self.assertRaises(AttributeError): + assert isinstance(tuple_row, tuple) + with pytest.raises(AttributeError): tuple_row.release_version # make sure original profile is not impacted - self.assertTrue(session.execute(query, 
execution_profile='node1').one().release_version) + assert session.execute(query, execution_profile='node1').one().release_version def test_setting_lbp_legacy(self): cluster = TestCluster() self.addCleanup(cluster.shutdown) cluster.load_balancing_policy = RoundRobinPolicy() - self.assertEqual( - list(cluster.load_balancing_policy.make_query_plan()), [] - ) + assert list(cluster.load_balancing_policy.make_query_plan()) == [] cluster.connect() - self.assertNotEqual( - list(cluster.load_balancing_policy.make_query_plan()), [] - ) + assert list(cluster.load_balancing_policy.make_query_plan()) != [] def test_profile_lb_swap(self): """ @@ -925,7 +928,7 @@ def test_profile_lb_swap(self): rs = session.execute(query, execution_profile='rr2') rr2_queried_hosts.add(rs.response_future._current_host) - self.assertEqual(rr2_queried_hosts, rr1_queried_hosts) + assert rr2_queried_hosts == rr1_queried_hosts def test_ta_lbp(self): """ @@ -962,7 +965,7 @@ def test_clone_shared_lbp(self): exec_profiles = {'rr1': rr1} with TestCluster(execution_profiles=exec_profiles) as cluster: session = cluster.connect(wait_for_all_pools=True) - self.assertGreater(len(cluster.metadata.all_hosts()), 1, "We only have one host connected at this point") + assert len(cluster.metadata.all_hosts()) > 1, "We only have one host connected at this point" rr1_clone = session.execution_profile_clone_update('rr1', row_factory=tuple_factory) cluster.add_execution_profile("rr1_clone", rr1_clone) @@ -972,7 +975,7 @@ def test_clone_shared_lbp(self): rr1_queried_hosts.add(rs.response_future._current_host) rs = session.execute(query, execution_profile='rr1_clone') rr1_clone_queried_hosts.add(rs.response_future._current_host) - self.assertNotEqual(rr1_clone_queried_hosts, rr1_queried_hosts) + assert rr1_clone_queried_hosts != rr1_queried_hosts def test_missing_exec_prof(self): """ @@ -990,7 +993,7 @@ def test_missing_exec_prof(self): exec_profiles = {'rr1': rr1, 'rr2': rr2} with 
TestCluster(execution_profiles=exec_profiles) as cluster: session = cluster.connect() - with self.assertRaises(ValueError): + with pytest.raises(ValueError): session.execute(query, execution_profile='rr3') @local @@ -1019,8 +1022,8 @@ def test_profile_pool_management(self): session = cluster.connect(wait_for_all_pools=True) pools = session.get_pool_state() # there are more hosts, but we connected to the ones in the lbp aggregate - self.assertGreater(len(cluster.metadata.all_hosts()), 2) - self.assertEqual(set(h.address for h in pools), set(('127.0.0.1', '127.0.0.2'))) + assert len(cluster.metadata.all_hosts()) > 2 + assert set(h.address for h in pools) == set(('127.0.0.1', '127.0.0.2')) # dynamically update pools on add node3 = ExecutionProfile( @@ -1030,7 +1033,7 @@ def test_profile_pool_management(self): ) cluster.add_execution_profile('node3', node3) pools = session.get_pool_state() - self.assertEqual(set(h.address for h in pools), set(('127.0.0.1', '127.0.0.2', '127.0.0.3'))) + assert set(h.address for h in pools) == set(('127.0.0.1', '127.0.0.2', '127.0.0.3')) @local def test_add_profile_timeout(self): @@ -1053,8 +1056,8 @@ def test_add_profile_timeout(self): with TestCluster(execution_profiles={EXEC_PROFILE_DEFAULT: node1}) as cluster: session = cluster.connect(wait_for_all_pools=True) pools = session.get_pool_state() - self.assertGreater(len(cluster.metadata.all_hosts()), 2) - self.assertEqual(set(h.address for h in pools), set(('127.0.0.1',))) + assert len(cluster.metadata.all_hosts()) > 2 + assert set(h.address for h in pools) == set(('127.0.0.1',)) node2 = ExecutionProfile( load_balancing_policy=HostFilterPolicy( @@ -1064,13 +1067,13 @@ def test_add_profile_timeout(self): start = time.time() try: - self.assertRaises(cassandra.OperationTimedOut, cluster.add_execution_profile, - 'profile_{0}'.format(i), + with pytest.raises(cassandra.OperationTimedOut): + cluster.add_execution_profile('profile_{0}'.format(i), node2, pool_wait_timeout=sys.float_info.min) 
break except AssertionError: end = time.time() - self.assertAlmostEqual(start, end, 1) + assert start == pytest.approx(end, abs=1e-1) else: raise Exception("add_execution_profile didn't timeout after {0} retries".format(max_retry_count)) @@ -1112,7 +1115,7 @@ def test_execute_query_timeout(self): # default is passed down default_profile = cluster.profile_manager.profiles[EXEC_PROFILE_DEFAULT] rs = session.execute(query) - self.assertEqual(rs.response_future.timeout, default_profile.request_timeout) + assert rs.response_future.timeout == default_profile.request_timeout # tiny timeout times out as expected tmp_profile = copy(default_profile) @@ -1122,14 +1125,14 @@ def test_execute_query_timeout(self): for _ in range(max_retry_count): start = time.time() try: - with self.assertRaises(cassandra.OperationTimedOut): + with pytest.raises(cassandra.OperationTimedOut): session.execute(query, execution_profile=tmp_profile) break except: import traceback traceback.print_exc() end = time.time() - self.assertAlmostEqual(start, end, 1) + assert start == pytest.approx(end, abs=1e-1) else: raise Exception("session.execute didn't time out in {0} tries".format(max_retry_count)) @@ -1227,41 +1230,37 @@ def test_compact_option(self): "({i}, 'a{i}{i}', {i}{i}, {i}{i}, textAsBlob('b{i}{i}'))".format(i=i)) nc_results = nc_session.execute("SELECT * FROM compact_table") - self.assertEqual( - set(nc_results.current_rows), - {(1, u'a1', 11, 11, 'b1'), - (1, u'a11', 11, 11, 'b11'), - (2, u'a2', 22, 22, 'b2'), - (2, u'a22', 22, 22, 'b22'), - (3, u'a3', 33, 33, 'b3'), - (3, u'a33', 33, 33, 'b33'), - (4, u'a4', 44, 44, 'b4'), - (4, u'a44', 44, 44, 'b44')}) + assert set(nc_results.current_rows) == {(1, u'a1', 11, 11, 'b1'), + (1, u'a11', 11, 11, 'b11'), + (2, u'a2', 22, 22, 'b2'), + (2, u'a22', 22, 22, 'b22'), + (3, u'a3', 33, 33, 'b3'), + (3, u'a33', 33, 33, 'b33'), + (4, u'a4', 44, 44, 'b4'), + (4, u'a44', 44, 44, 'b44')} results = session.execute("SELECT * FROM compact_table") - 
self.assertEqual( - set(results.current_rows), - {(1, 11, 11), - (2, 22, 22), - (3, 33, 33), - (4, 44, 44)}) + assert set(results.current_rows) == {(1, 11, 11), + (2, 22, 22), + (3, 33, 33), + (4, 44, 44)} def _assert_replica_queried(self, trace, only_replicas=True): queried_hosts = set() for row in trace.events: queried_hosts.add(row.source) if only_replicas: - self.assertEqual(len(queried_hosts), 1, "The hosts queried where {}".format(queried_hosts)) + assert len(queried_hosts) == 1, "The hosts queried where {}".format(queried_hosts) else: - self.assertGreater(len(queried_hosts), 1, "The host queried was {}".format(queried_hosts)) + assert len(queried_hosts) > 1, "The host queried was {}".format(queried_hosts) return queried_hosts def _check_trace(self, trace): - self.assertIsNotNone(trace.request_type) - self.assertIsNotNone(trace.duration) - self.assertIsNotNone(trace.started_at) - self.assertIsNotNone(trace.coordinator) - self.assertIsNotNone(trace.events) + assert trace.request_type is not None + assert trace.duration is not None + assert trace.started_at is not None + assert trace.coordinator is not None + assert trace.events is not None class LocalHostAdressTranslator(AddressTranslator): @@ -1293,7 +1292,7 @@ def test_address_translator_basic(self): lh_ad = LocalHostAdressTranslator({'127.0.0.1': '127.0.0.1', '127.0.0.2': '127.0.0.1', '127.0.0.3': '127.0.0.1'}) c = TestCluster(address_translator=lh_ad) c.connect() - self.assertEqual(len(c.metadata.all_hosts()), 1) + assert len(c.metadata.all_hosts()) == 1 c.shutdown() def test_address_translator_with_mixed_nodes(self): @@ -1314,7 +1313,7 @@ def test_address_translator_with_mixed_nodes(self): c = TestCluster(address_translator=lh_ad) c.connect() for host in c.metadata.all_hosts(): - self.assertEqual(adder_map.get(host.address), host.broadcast_address) + assert adder_map.get(host.address) == host.broadcast_address c.shutdown() @local @@ -1338,8 +1337,8 @@ def test_no_connect(self): @test_category 
configuration """ with TestCluster() as cluster: - self.assertFalse(cluster.is_shutdown) - self.assertTrue(cluster.is_shutdown) + assert not cluster.is_shutdown + assert cluster.is_shutdown def test_simple_nested(self): """ @@ -1353,11 +1352,11 @@ def test_simple_nested(self): """ with TestCluster(**self.cluster_kwargs) as cluster: with cluster.connect() as session: - self.assertFalse(cluster.is_shutdown) - self.assertFalse(session.is_shutdown) - self.assertTrue(session.execute('select release_version from system.local').one()) - self.assertTrue(session.is_shutdown) - self.assertTrue(cluster.is_shutdown) + assert not cluster.is_shutdown + assert not session.is_shutdown + assert session.execute('select release_version from system.local').one() + assert session.is_shutdown + assert cluster.is_shutdown def test_cluster_no_session(self): """ @@ -1371,11 +1370,11 @@ def test_cluster_no_session(self): """ with TestCluster(**self.cluster_kwargs) as cluster: session = cluster.connect() - self.assertFalse(cluster.is_shutdown) - self.assertFalse(session.is_shutdown) - self.assertTrue(session.execute('select release_version from system.local').one()) - self.assertTrue(session.is_shutdown) - self.assertTrue(cluster.is_shutdown) + assert not cluster.is_shutdown + assert not session.is_shutdown + assert session.execute('select release_version from system.local').one() + assert session.is_shutdown + assert cluster.is_shutdown def test_session_no_cluster(self): """ @@ -1390,18 +1389,18 @@ def test_session_no_cluster(self): cluster = TestCluster(**self.cluster_kwargs) unmanaged_session = cluster.connect() with cluster.connect() as session: - self.assertFalse(cluster.is_shutdown) - self.assertFalse(session.is_shutdown) - self.assertFalse(unmanaged_session.is_shutdown) - self.assertTrue(session.execute('select release_version from system.local').one()) - self.assertTrue(session.is_shutdown) - self.assertFalse(cluster.is_shutdown) - self.assertFalse(unmanaged_session.is_shutdown) + 
assert not cluster.is_shutdown + assert not session.is_shutdown + assert not unmanaged_session.is_shutdown + assert session.execute('select release_version from system.local').one() + assert session.is_shutdown + assert not cluster.is_shutdown + assert not unmanaged_session.is_shutdown unmanaged_session.shutdown() - self.assertTrue(unmanaged_session.is_shutdown) - self.assertFalse(cluster.is_shutdown) + assert unmanaged_session.is_shutdown + assert not cluster.is_shutdown cluster.shutdown() - self.assertTrue(cluster.is_shutdown) + assert cluster.is_shutdown class HostStateTest(unittest.TestCase): @@ -1424,7 +1423,7 @@ def test_down_event_with_active_connection(self): cluster.on_down(random_host, False) for _ in range(10): new_host = cluster.metadata.all_hosts()[0] - self.assertTrue(new_host.is_up, "Host was not up on iteration {0}".format(_)) + assert new_host.is_up, "Host was not up on iteration {0}".format(_) time.sleep(.01) pool = session._pools.get(random_host) @@ -1437,7 +1436,7 @@ def test_down_event_with_active_connection(self): was_marked_down = True break time.sleep(.01) - self.assertTrue(was_marked_down) + assert was_marked_down @local @@ -1475,7 +1474,7 @@ def test_prepare_on_ignored_hosts(self): # the length of mock_calls will vary, but all should use the unignored # address for c in cluster.connection_factory.mock_calls: - self.assertEqual(unignored_address, c.args[0].address) + assert unignored_address == c.args[0].address cluster.shutdown() @@ -1497,10 +1496,10 @@ def test_invalid_protocol_version_beta_option(self): cluster = TestCluster(protocol_version=cassandra.ProtocolVersion.V6, allow_beta_protocol_version=False) try: - with self.assertRaises(NoHostAvailable): + with pytest.raises(NoHostAvailable): cluster.connect() except Exception as e: - self.fail("Unexpected error encountered {0}".format(e.message)) + pytest.fail("Unexpected error encountered {0}".format(e.message)) @protocolv6 def test_valid_protocol_version_beta_options_connect(self): @@ 
-1515,8 +1514,8 @@ def test_valid_protocol_version_beta_options_connect(self): """ cluster = Cluster(protocol_version=cassandra.ProtocolVersion.V6, allow_beta_protocol_version=True) session = cluster.connect() - self.assertEqual(cluster.protocol_version, cassandra.ProtocolVersion.V6) - self.assertTrue(session.execute("select release_version from system.local").one()) + assert cluster.protocol_version == cassandra.ProtocolVersion.V6 + assert session.execute("select release_version from system.local").one() cluster.shutdown() @@ -1535,10 +1534,10 @@ def test_deprecation_warnings_legacy_parameters(self): with warnings.catch_warnings(record=True) as w: TestCluster(load_balancing_policy=RoundRobinPolicy()) logging.info(w) - self.assertGreaterEqual(len(w), 1) - self.assertTrue(any(["Legacy execution parameters will be removed in 4.0. " + assert len(w) >= 1 + assert any(["Legacy execution parameters will be removed in 4.0. " "Consider using execution profiles." in - str(wa.message) for wa in w])) + str(wa.message) for wa in w]) def test_deprecation_warnings_meta_refreshed(self): """ @@ -1555,9 +1554,9 @@ def test_deprecation_warnings_meta_refreshed(self): cluster = TestCluster() cluster.set_meta_refresh_enabled(True) logging.info(w) - self.assertGreaterEqual(len(w), 1) - self.assertTrue(any(["Cluster.set_meta_refresh_enabled is deprecated and will be removed in 4.0." in - str(wa.message) for wa in w])) + assert len(w) >= 1 + assert any(["Cluster.set_meta_refresh_enabled is deprecated and will be removed in 4.0." 
in + str(wa.message) for wa in w]) def test_deprecation_warning_default_consistency_level(self): """ @@ -1574,6 +1573,6 @@ def test_deprecation_warning_default_consistency_level(self): cluster = TestCluster() session = cluster.connect() session.default_consistency_level = ConsistencyLevel.ONE - self.assertGreaterEqual(len(w), 1) - self.assertTrue(any(["Setting the consistency level at the session level will be removed in 4.0" in - str(wa.message) for wa in w])) + assert len(w) >= 1 + assert any(["Setting the consistency level at the session level will be removed in 4.0" in + str(wa.message) for wa in w]) diff --git a/tests/integration/standard/test_concurrent.py b/tests/integration/standard/test_concurrent.py index ba891b4bd0..e4bd379dee 100644 --- a/tests/integration/standard/test_concurrent.py +++ b/tests/integration/standard/test_concurrent.py @@ -25,6 +25,7 @@ from tests.integration import use_singledc, PROTOCOL_VERSION, TestCluster import unittest +import pytest log = logging.getLogger(__name__) @@ -90,10 +91,10 @@ def execute_concurrent_base(self, test_fn, validate_fn, zip_args=True): results = \ test_fn(self.session, list(zip(statements, parameters))) if zip_args else \ test_fn(self.session, statement, parameters) - self.assertEqual(num_statements, len(results)) + assert num_statements == len(results) for success, result in results: - self.assertTrue(success) - self.assertFalse(result) + assert success + assert not result # read statement = SimpleStatement( @@ -108,12 +109,12 @@ def execute_concurrent_base(self, test_fn, validate_fn, zip_args=True): validate_fn(num_statements, results) def execute_concurrent_valiate_tuple(self, num_statements, results): - self.assertEqual(num_statements, len(results)) - self.assertEqual([(True, [(i,)]) for i in range(num_statements)], results) + assert num_statements == len(results) + assert [(True, [(i,)]) for i in range(num_statements)] == results def execute_concurrent_valiate_dict(self, num_statements, results): - 
self.assertEqual(num_statements, len(results)) - self.assertEqual([(True, [{"v":i}]) for i in range(num_statements)], results) + assert num_statements == len(results) + assert [(True, [{"v":i}]) for i in range(num_statements)] == results def test_execute_concurrent(self): self.execute_concurrent_base(self.execute_concurrent_helper, \ @@ -155,14 +156,14 @@ def test_execute_concurrent_with_args_generator(self): results = self.execute_concurrent_args_helper(self.session, statement, parameters, results_generator=True) for success, result in results: - self.assertTrue(success) - self.assertFalse(result) + assert success + assert not result results = self.execute_concurrent_args_helper(self.session, statement, parameters, results_generator=True) for result in results: - self.assertTrue(isinstance(result, ExecutionResult)) - self.assertTrue(result.success) - self.assertFalse(result.result_or_exc) + assert isinstance(result, ExecutionResult) + assert result.success + assert not result.result_or_exc # read statement = SimpleStatement( @@ -174,8 +175,9 @@ def test_execute_concurrent_with_args_generator(self): for i in range(num_statements): result = next(results) - self.assertEqual((True, [(i,)]), result) - self.assertRaises(StopIteration, next, results) + assert (True, [(i,)]) == result + with pytest.raises(StopIteration): + next(results) def test_execute_concurrent_paged_result(self): if PROTOCOL_VERSION < 2: @@ -190,10 +192,10 @@ def test_execute_concurrent_paged_result(self): parameters = [(i, i) for i in range(num_statements)] results = self.execute_concurrent_args_helper(self.session, statement, parameters) - self.assertEqual(num_statements, len(results)) + assert num_statements == len(results) for success, result in results: - self.assertTrue(success) - self.assertFalse(result) + assert success + assert not result # read statement = SimpleStatement( @@ -202,11 +204,11 @@ def test_execute_concurrent_paged_result(self): fetch_size=int(num_statements / 2)) results = 
self.execute_concurrent_args_helper(self.session, statement, [(num_statements,)]) - self.assertEqual(1, len(results)) - self.assertTrue(results[0][0]) + assert 1 == len(results) + assert results[0][0] result = results[0][1] - self.assertTrue(result.has_more_pages) - self.assertEqual(num_statements, sum(1 for _ in result)) + assert result.has_more_pages + assert num_statements == sum(1 for _ in result) def test_execute_concurrent_paged_result_generator(self): """ @@ -233,7 +235,7 @@ def test_execute_concurrent_paged_result_generator(self): parameters = [(i, i) for i in range(num_statements)] results = self.execute_concurrent_args_helper(self.session, statement, parameters, results_generator=True) - self.assertEqual(num_statements, sum(1 for _ in results)) + assert num_statements == sum(1 for _ in results) # read statement = SimpleStatement( @@ -250,7 +252,7 @@ def test_execute_concurrent_paged_result_generator(self): for _ in paged_result: found_results += 1 - self.assertEqual(found_results, num_statements) + assert found_results == num_statements def test_first_failure(self): statements = cycle(("INSERT INTO test3rf.test (k, v) VALUES (%s, %s)", )) @@ -259,9 +261,8 @@ def test_first_failure(self): # we'll get an error back from the server parameters[57] = ('efefef', 'awefawefawef') - self.assertRaises( - InvalidRequest, - execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True) + with pytest.raises(InvalidRequest): + execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=True) def test_first_failure_client_side(self): statement = SimpleStatement( @@ -273,9 +274,8 @@ def test_first_failure_client_side(self): # the driver will raise an error when binding the params parameters[57] = 1 - self.assertRaises( - TypeError, - execute_concurrent, self.session, list(zip(statements, parameters)), raise_on_first_error=True) + with pytest.raises(TypeError): + execute_concurrent(self.session, 
list(zip(statements, parameters)), raise_on_first_error=True) def test_no_raise_on_first_failure(self): statement = SimpleStatement( @@ -290,11 +290,11 @@ def test_no_raise_on_first_failure(self): results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False) for i, (success, result) in enumerate(results): if i == 57: - self.assertFalse(success) - self.assertIsInstance(result, InvalidRequest) + assert not success + assert isinstance(result, InvalidRequest) else: - self.assertTrue(success) - self.assertFalse(result) + assert success + assert not result def test_no_raise_on_first_failure_client_side(self): statement = SimpleStatement( @@ -309,8 +309,8 @@ def test_no_raise_on_first_failure_client_side(self): results = execute_concurrent(self.session, list(zip(statements, parameters)), raise_on_first_error=False) for i, (success, result) in enumerate(results): if i == 57: - self.assertFalse(success) - self.assertIsInstance(result, TypeError) + assert not success + assert isinstance(result, TypeError) else: - self.assertTrue(success) - self.assertFalse(result) + assert success + assert not result diff --git a/tests/integration/standard/test_connection.py b/tests/integration/standard/test_connection.py index 4a5c23d6bf..630e5e6ba0 100644 --- a/tests/integration/standard/test_connection.py +++ b/tests/integration/standard/test_connection.py @@ -22,6 +22,7 @@ from threading import Thread, Event import time from unittest import SkipTest +import pytest from cassandra import ConsistencyLevel, OperationTimedOut, DependencyException from cassandra.cluster import NoHostAvailable, ConnectionShutdown, ExecutionProfile, EXEC_PROFILE_DEFAULT @@ -131,7 +132,7 @@ def test_heart_beat_timeout(self): host = "127.0.0.1:9042" node = get_node(1) initial_connections = self.fetch_connections(host, self.cluster) - self.assertNotEqual(len(initial_connections), 0) + assert len(initial_connections) != 0 self.cluster.register_listener(test_listener) # 
Pause the node try: @@ -141,7 +142,7 @@ def test_heart_beat_timeout(self): # Wait to seconds for the driver to be notified time.sleep(2) - self.assertTrue(test_listener.host_down) + assert test_listener.host_down # Resume paused node finally: node.resume() @@ -154,12 +155,12 @@ def test_heart_beat_timeout(self): current_host = str(rs._current_host) count += 1 time.sleep(.1) - self.assertLess(count, 100, "Never connected to the first node") + assert count < 100, "Never connected to the first node" new_connections = self.wait_for_connections(host, self.cluster) - self.assertFalse(test_listener.host_down) + assert not test_listener.host_down # Make sure underlying new connections don't match previous ones for connection in initial_connections: - self.assertFalse(connection in new_connections) + assert not connection in new_connections def fetch_connections(self, host, cluster): # Given a cluster object and host grab all connection associated with that host @@ -179,7 +180,7 @@ def wait_for_connections(self, host, cluster): if connections: return connections time.sleep(.1) - self.fail("No new connections found") + pytest.fail("No new connections found") def wait_for_no_connections(self, host, cluster): retry = 0 @@ -189,7 +190,7 @@ def wait_for_no_connections(self, host, cluster): if not connections: return time.sleep(.5) - self.fail("Connections never cleared") + pytest.fail("Connections never cleared") class ConnectionTests(object): @@ -396,10 +397,10 @@ def test_connect_timeout(self): conn.close() except Exception as e: end = time.time() - self.assertAlmostEqual(start, end, 1) + assert start == pytest.approx(end, abs=1e-1) exception_thrown = True break - self.assertTrue(exception_thrown) + assert exception_thrown def test_subclasses_share_loop(self): @@ -420,7 +421,7 @@ class C2(self.klass): self.addCleanup(clusterC1.shutdown) self.addCleanup(clusterC2.shutdown) - self.assertEqual(len(get_eventloop_threads(self.event_loop_name)), 1) + assert 
len(get_eventloop_threads(self.event_loop_name)) == 1 def get_eventloop_threads(name): diff --git a/tests/integration/standard/test_control_connection.py b/tests/integration/standard/test_control_connection.py index ea434c37c5..32cc468e64 100644 --- a/tests/integration/standard/test_control_connection.py +++ b/tests/integration/standard/test_control_connection.py @@ -72,7 +72,7 @@ def test_drop_keyspace(self): cc_id_pre_drop = id(self.cluster.control_connection._connection) self.session.execute("DROP KEYSPACE keyspacetodrop") cc_id_post_drop = id(self.cluster.control_connection._connection) - self.assertEqual(cc_id_post_drop, cc_id_pre_drop) + assert cc_id_post_drop == cc_id_pre_drop def test_get_control_connection_host(self): """ @@ -86,19 +86,19 @@ def test_get_control_connection_host(self): """ host = self.cluster.get_control_connection_host() - self.assertEqual(host, None) + assert host == None self.session = self.cluster.connect() cc_host = self.cluster.control_connection._connection.host host = self.cluster.get_control_connection_host() - self.assertEqual(host.address, cc_host) - self.assertEqual(host.is_up, True) + assert host.address == cc_host + assert host.is_up == True # reconnect and make sure that the new host is reflected correctly self.cluster.control_connection._reconnect() new_host = self.cluster.get_control_connection_host() - self.assertNotEqual(host, new_host) + assert host != new_host # TODO: enable after https://github.com/scylladb/python-driver/issues/121 is fixed @unittest.skip('Fails on scylla due to the broadcast_rpc_port is None') @@ -113,17 +113,17 @@ def test_control_connection_port_discovery(self): self.cluster = TestCluster() host = self.cluster.get_control_connection_host() - self.assertEqual(host, None) + assert host == None self.session = self.cluster.connect() cc_endpoint = self.cluster.control_connection._connection.endpoint host = self.cluster.get_control_connection_host() - self.assertEqual(host.endpoint, cc_endpoint) - 
self.assertEqual(host.is_up, True) + assert host.endpoint == cc_endpoint + assert host.is_up == True hosts = self.cluster.metadata.all_hosts() - self.assertEqual(3, len(hosts)) + assert 3 == len(hosts) for host in hosts: - self.assertEqual(9042, host.broadcast_rpc_port) - self.assertEqual(7000, host.broadcast_port) + assert 9042 == host.broadcast_rpc_port + assert 7000 == host.broadcast_port diff --git a/tests/integration/standard/test_custom_cluster.py b/tests/integration/standard/test_custom_cluster.py index c1eabbfd1f..4eb62e43bc 100644 --- a/tests/integration/standard/test_custom_cluster.py +++ b/tests/integration/standard/test_custom_cluster.py @@ -17,6 +17,7 @@ from tests.util import wait_until, wait_until_not_raised import unittest +import pytest def setup_module(): @@ -44,7 +45,7 @@ def test_connection_honor_cluster_port(self): All hosts should be marked as up and we should be able to execute queries on it. """ cluster = TestCluster() - with self.assertRaises(NoHostAvailable): + with pytest.raises(NoHostAvailable): cluster.connect() # should fail on port 9042 cluster = TestCluster(port=9046) @@ -52,5 +53,5 @@ def test_connection_honor_cluster_port(self): wait_until(lambda: len(cluster.metadata.all_hosts()) == 3, 1, 5) for host in cluster.metadata.all_hosts(): - self.assertTrue(host.is_up) + assert host.is_up session.execute("select * from system.local where key='local'", host=host) diff --git a/tests/integration/standard/test_custom_payload.py b/tests/integration/standard/test_custom_payload.py index 92372972c6..fc58081070 100644 --- a/tests/integration/standard/test_custom_payload.py +++ b/tests/integration/standard/test_custom_payload.py @@ -19,6 +19,7 @@ from tests.integration import (use_singledc, PROTOCOL_VERSION, local, TestCluster, requires_custom_payload) +import pytest def setup_module(): @@ -148,7 +149,7 @@ def validate_various_custom_payloads(self, statement): # Add one custom payload to this is too many key value pairs and should fail 
custom_payload[str(65535)] = b'x' - with self.assertRaises(ValueError): + with pytest.raises(ValueError): self.execute_async_validate_custom_payload(statement=statement, custom_payload=custom_payload) def execute_async_validate_custom_payload(self, statement, custom_payload): @@ -164,4 +165,4 @@ def execute_async_validate_custom_payload(self, statement, custom_payload): response_future = self.session.execute_async(statement, custom_payload=custom_payload) response_future.result() returned_custom_payload = response_future.custom_payload - self.assertEqual(custom_payload, returned_custom_payload) + assert custom_payload == returned_custom_payload diff --git a/tests/integration/standard/test_custom_protocol_handler.py b/tests/integration/standard/test_custom_protocol_handler.py index 26d3f5fe35..a9025bba97 100644 --- a/tests/integration/standard/test_custom_protocol_handler.py +++ b/tests/integration/standard/test_custom_protocol_handler.py @@ -28,6 +28,7 @@ import uuid from unittest import mock +import pytest def setup_module(): @@ -71,20 +72,20 @@ def test_custom_raw_uuid_row_results(self): result = session.execute("SELECT schema_version FROM system.local WHERE key='local'") uuid_type = result.one()[0] - self.assertEqual(type(uuid_type), uuid.UUID) + assert type(uuid_type) == uuid.UUID # use our custom protocol handlder session.client_protocol_handler = CustomTestRawRowType result_set = session.execute("SELECT schema_version FROM system.local WHERE key='local'") raw_value = result_set.one()[0] - self.assertTrue(isinstance(raw_value, bytes)) - self.assertEqual(len(raw_value), 16) + assert isinstance(raw_value, bytes) + assert len(raw_value) == 16 # Ensure that we get normal uuid back when we re-connect session.client_protocol_handler = ProtocolHandler result_set = session.execute("SELECT schema_version FROM system.local WHERE key='local'") uuid_type = result_set.one()[0] - self.assertEqual(type(uuid_type), uuid.UUID) + assert type(uuid_type) == uuid.UUID 
cluster.shutdown() def test_custom_raw_row_results_all_types(self): @@ -115,9 +116,9 @@ def test_custom_raw_row_results_all_types(self): params = get_all_primitive_params(0) results = session.execute("SELECT {0} FROM alltypes WHERE primkey=0".format(columns_string)).one() for expected, actual in zip(params, results): - self.assertEqual(actual, expected) + assert actual == expected # Ensure we have covered the various primitive types - self.assertEqual(len(CustomResultMessageTracked.checked_rev_row_set), len(PRIMITIVE_DATATYPES)-1) + assert len(CustomResultMessageTracked.checked_rev_row_set) == len(PRIMITIVE_DATATYPES)-1 cluster.shutdown() @unittest.expectedFailure @@ -145,10 +146,9 @@ def test_protocol_divergence_v5_fail_by_continuous_paging(self): continuous_paging_options=continuous_paging_options) # This should raise NoHostAvailable because continuous paging is not supported under ProtocolVersion.DSE_V1 - with self.assertRaises(NoHostAvailable) as context: + with pytest.raises(NoHostAvailable) as context: future.result() - self.assertIn("Continuous paging may only be used with protocol version ProtocolVersion.DSE_V1 or higher", - str(context.exception)) + assert "Continuous paging may only be used with protocol version ProtocolVersion.DSE_V1 or higher" in str(context.value) cluster.shutdown() @@ -219,7 +219,7 @@ def _protocol_divergence_fail_by_flag_uses_int(self, version, uses_int_query_fla response = future.result() # This means the flag are not handled as they are meant by the server if uses_int=False - self.assertEqual(response.has_more_pages, uses_int_query_flag) + assert response.has_more_pages == uses_int_query_flag execute_with_long_wait_retry(session, SimpleStatement("TRUNCATE test3rf.test")) cluster.shutdown() diff --git a/tests/integration/standard/test_cython_protocol_handlers.py b/tests/integration/standard/test_cython_protocol_handlers.py index 9e85edb914..9ec16cecc7 100644 --- a/tests/integration/standard/test_cython_protocol_handlers.py +++ 
b/tests/integration/standard/test_cython_protocol_handlers.py @@ -47,14 +47,14 @@ def test_cython_parser(self): """ Test Cython-based parser that returns a list of tuples """ - verify_iterator_data(self.assertEqual, get_data(ProtocolHandler)) + verify_iterator_data(get_data(ProtocolHandler)) @cythontest def test_cython_lazy_parser(self): """ Test Cython-based parser that returns an iterator of tuples """ - verify_iterator_data(self.assertEqual, get_data(LazyProtocolHandler)) + verify_iterator_data(get_data(LazyProtocolHandler)) @numpytest def test_cython_lazy_results_paged(self): @@ -69,12 +69,12 @@ def test_cython_lazy_results_paged(self): session.client_protocol_handler = LazyProtocolHandler session.default_fetch_size = 2 - self.assertLess(session.default_fetch_size, self.N_ITEMS) + assert session.default_fetch_size < self.N_ITEMS results = session.execute("SELECT * FROM test_table") - self.assertTrue(results.has_more_pages) - self.assertEqual(verify_iterator_data(self.assertEqual, results), self.N_ITEMS) # make sure we see all rows + assert results.has_more_pages + assert verify_iterator_data(results) == self.N_ITEMS # make sure we see all rows cluster.shutdown() @@ -86,7 +86,7 @@ def test_numpy_parser(self): """ # arrays = { 'a': arr1, 'b': arr2, ... 
} result = get_data(NumpyProtocolHandler) - self.assertFalse(result.has_more_pages) + assert not result.has_more_pages self._verify_numpy_page(result[0]) @notprotocolv1 @@ -105,23 +105,23 @@ def test_numpy_results_paged(self): expected_pages = (self.N_ITEMS + session.default_fetch_size - 1) // session.default_fetch_size - self.assertLess(session.default_fetch_size, self.N_ITEMS) + assert session.default_fetch_size < self.N_ITEMS results = session.execute("SELECT * FROM test_table") - self.assertTrue(results.has_more_pages) + assert results.has_more_pages for count, page in enumerate(results, 1): - self.assertIsInstance(page, dict) + assert isinstance(page, dict) for colname, arr in page.items(): if count <= expected_pages: - self.assertGreater(len(arr), 0, "page count: %d" % (count,)) - self.assertLessEqual(len(arr), session.default_fetch_size) + assert len(arr) > 0, "page count: %d" % (count,) + assert len(arr) <= session.default_fetch_size else: # we get one extra item out of this iteration because of the way NumpyParser returns results # The last page is returned as a dict with zero-length arrays - self.assertEqual(len(arr), 0) - self.assertEqual(self._verify_numpy_page(page), len(arr)) - self.assertEqual(count, expected_pages + 1) # see note about extra 'page' above + assert len(arr) == 0 + assert self._verify_numpy_page(page) == len(arr) + assert count == expected_pages + 1 # see note about extra 'page' above cluster.shutdown() @@ -136,8 +136,8 @@ def test_cython_numpy_are_installed_valid(self): @test_category configuration """ if VERIFY_CYTHON: - self.assertTrue(HAVE_CYTHON) - self.assertTrue(HAVE_NUMPY) + assert HAVE_CYTHON + assert HAVE_NUMPY def _verify_numpy_page(self, page): colnames = self.colnames @@ -146,7 +146,7 @@ def _verify_numpy_page(self, page): arr = page[colname] self.match_dtype(datatype, arr.dtype) - return verify_iterator_data(self.assertEqual, arrays_to_list_of_tuples(page, colnames)) + return 
verify_iterator_data(arrays_to_list_of_tuples(page, colnames)) def match_dtype(self, datatype, dtype): """Match a string cqltype (e.g. 'int' or 'blob') with a numpy dtype""" @@ -161,11 +161,11 @@ def match_dtype(self, datatype, dtype): elif datatype == 'double': self.match_dtype_props(dtype, 'f', 8) else: - self.assertEqual(dtype.kind, 'O', msg=(dtype, datatype)) + assert dtype.kind == 'O', (dtype, datatype) def match_dtype_props(self, dtype, kind, size, signed=None): - self.assertEqual(dtype.kind, kind, msg=dtype) - self.assertEqual(dtype.itemsize, size, msg=dtype) + assert dtype.kind == kind, dtype + assert dtype.itemsize == size, dtype def arrays_to_list_of_tuples(arrays, colnames): @@ -192,7 +192,7 @@ def get_data(protocol_handler): return results -def verify_iterator_data(assertEqual, results): +def verify_iterator_data(results): """ Check the result of get_data() when this is a list or iterator of tuples @@ -200,10 +200,9 @@ def verify_iterator_data(assertEqual, results): count = 0 for count, result in enumerate(results, 1): params = get_all_primitive_params(result[0]) - assertEqual(len(params), len(result), - msg="Not the right number of columns?") + assert len(params) == len(result), "Not the right number of columns?" 
for expected, actual in zip(params, result): - assertEqual(actual, expected) + assert actual == expected return count @@ -250,11 +249,15 @@ def test_null_types(self): # because None and `masked` have different identity and equals semantics if isinstance(col_array, MaskedArray): had_masked = True - [self.assertIsNot(col_array[i], masked) for i in mapped_index[:begin_unset]] - [self.assertIs(col_array[i], masked) for i in mapped_index[begin_unset:]] + for i in mapped_index[:begin_unset]: + assert col_array[i] is not masked + for i in mapped_index[begin_unset:]: + assert col_array[i] is masked else: had_none = True - [self.assertIsNotNone(col_array[i]) for i in mapped_index[:begin_unset]] - [self.assertIsNone(col_array[i]) for i in mapped_index[begin_unset:]] - self.assertTrue(had_masked) - self.assertTrue(had_none) + for i in mapped_index[:begin_unset]: + assert col_array[i] is not None + for i in mapped_index[begin_unset:]: + assert col_array[i] is None + assert had_masked + assert had_none diff --git a/tests/integration/standard/test_metadata.py b/tests/integration/standard/test_metadata.py index 8d677030f9..eaaf9ad3a9 100644 --- a/tests/integration/standard/test_metadata.py +++ b/tests/integration/standard/test_metadata.py @@ -39,12 +39,12 @@ BasicExistingKeyspaceUnitTestCase, drop_keyspace_shutdown_cluster, CASSANDRA_VERSION, greaterthanorequalcass30, lessthancass30, local, get_supported_protocol_versions, greaterthancass20, - greaterthancass21, assert_startswith, greaterthanorequalcass40, + greaterthancass21, greaterthanorequalcass40, lessthancass40, TestCluster, requires_java_udf, requires_composite_type, requires_collection_indexes, SCYLLA_VERSION, xfail_scylla, xfail_scylla_version_lt) -from tests.util import wait_until +from tests.util import wait_until, assertRegex, assertDictEqual, assertListEqual, assert_startswith_diff log = logging.getLogger(__name__) @@ -70,24 +70,24 @@ def test_host_addresses(self): """ # All nodes should have the broadcast_address, 
rpc_address and host_id set for host in self.cluster.metadata.all_hosts(): - self.assertIsNotNone(host.broadcast_address) - self.assertIsNotNone(host.broadcast_rpc_address) - self.assertIsNotNone(host.host_id) + assert host.broadcast_address is not None + assert host.broadcast_rpc_address is not None + assert host.host_id is not None if CASSANDRA_VERSION >= Version('4-a'): - self.assertIsNotNone(host.broadcast_port) - self.assertIsNotNone(host.broadcast_rpc_port) + assert host.broadcast_port is not None + assert host.broadcast_rpc_port is not None con = self.cluster.control_connection.get_connections()[0] local_host = con.host # The control connection node should have the listen address set. listen_addrs = [host.listen_address for host in self.cluster.metadata.all_hosts()] - self.assertTrue(local_host in listen_addrs) + assert local_host in listen_addrs # The control connection node should have the broadcast_rpc_address set. rpc_addrs = [host.broadcast_rpc_address for host in self.cluster.metadata.all_hosts()] - self.assertTrue(local_host in rpc_addrs) + assert local_host in rpc_addrs @unittest.skipUnless( os.getenv('MAPPED_CASSANDRA_VERSION', None) is not None, @@ -104,7 +104,7 @@ def test_host_release_version(self): @test_category metadata """ for host in self.cluster.metadata.all_hosts(): - assert_startswith(host.release_version, CASSANDRA_VERSION.base_version) + assert host.release_version.startswith(CASSANDRA_VERSION.base_version) @@ -133,7 +133,7 @@ def test_bad_contact_point(self): # verify the un-existing host was filtered for host in self.cluster.metadata.all_hosts(): - self.assertNotEqual(host.endpoint.address, '126.0.0.186') + assert host.endpoint.address != '126.0.0.186' class SchemaMetadataTests(BasicSegregatedKeyspaceUnitTestCase): @@ -153,18 +153,18 @@ def test_schema_metadata_disable(self): # Validate metadata is missing where appropriate no_schema = TestCluster(schema_metadata_enabled=False) no_schema_session = no_schema.connect() - 
self.assertEqual(len(no_schema.metadata.keyspaces), 0) - self.assertEqual(no_schema.metadata.export_schema_as_string(), '') + assert len(no_schema.metadata.keyspaces) == 0 + assert no_schema.metadata.export_schema_as_string() == '' no_token = TestCluster(token_metadata_enabled=False) no_token_session = no_token.connect() - self.assertEqual(len(no_token.metadata.token_map.token_to_host_owner), 0) + assert len(no_token.metadata.token_map.token_to_host_owner) == 0 # Do a simple query to ensure queries are working query = "SELECT * FROM system.local WHERE key='local'" no_schema_rs = no_schema_session.execute(query) no_token_rs = no_token_session.execute(query) - self.assertIsNotNone(no_schema_rs.one()) - self.assertIsNotNone(no_token_rs.one()) + assert no_schema_rs.one() is not None + assert no_token_rs.one() is not None no_schema.shutdown() no_token.shutdown() @@ -201,7 +201,7 @@ def make_create_statement(self, partition_cols, clustering_cols=None, other_cols def check_create_statement(self, tablemeta, original): recreate = tablemeta.as_cql_query(formatted=False) - self.assertEqual(original, recreate[:len(original)]) + assert original == recreate[:len(original)] execute_until_pass(self.session, "DROP TABLE {0}.{1}".format(self.keyspace_name, self.function_table_name)) execute_until_pass(self.session, recreate) @@ -221,24 +221,24 @@ def test_basic_table_meta_properties(self): self.cluster.refresh_schema_metadata() meta = self.cluster.metadata - self.assertNotEqual(meta.cluster_name, None) - self.assertTrue(self.keyspace_name in meta.keyspaces) + assert meta.cluster_name != None + assert self.keyspace_name in meta.keyspaces ksmeta = meta.keyspaces[self.keyspace_name] - self.assertEqual(ksmeta.name, self.keyspace_name) - self.assertTrue(ksmeta.durable_writes) - self.assertEqual(ksmeta.replication_strategy.name, 'SimpleStrategy') - self.assertEqual(ksmeta.replication_strategy.replication_factor, 1) + assert ksmeta.name == self.keyspace_name + assert ksmeta.durable_writes 
+ assert ksmeta.replication_strategy.name == 'SimpleStrategy' + assert ksmeta.replication_strategy.replication_factor == 1 - self.assertTrue(self.function_table_name in ksmeta.tables) + assert self.function_table_name in ksmeta.tables tablemeta = ksmeta.tables[self.function_table_name] - self.assertEqual(tablemeta.keyspace_name, ksmeta.name) - self.assertEqual(tablemeta.name, self.function_table_name) - self.assertEqual(tablemeta.name, self.function_table_name) + assert tablemeta.keyspace_name == ksmeta.name + assert tablemeta.name == self.function_table_name + assert tablemeta.name == self.function_table_name - self.assertEqual([u'a'], [c.name for c in tablemeta.partition_key]) - self.assertEqual([], tablemeta.clustering_key) - self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys())) + assert [u'a'] == [c.name for c in tablemeta.partition_key] + assert [] == tablemeta.clustering_key + assert [u'a', u'b', u'c'] == sorted(tablemeta.columns.keys()) cc = self.cluster.control_connection._connection parser = get_schema_parser( @@ -250,7 +250,7 @@ def test_basic_table_meta_properties(self): ) for option in tablemeta.options: - self.assertIn(option, parser.recognized_table_options) + assert option in parser.recognized_table_options self.check_create_statement(tablemeta, create_statement) @@ -260,9 +260,9 @@ def test_compound_primary_keys(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - self.assertEqual([u'a'], [c.name for c in tablemeta.partition_key]) - self.assertEqual([u'b'], [c.name for c in tablemeta.clustering_key]) - self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys())) + assert [u'a'] == [c.name for c in tablemeta.partition_key] + assert [u'b'] == [c.name for c in tablemeta.clustering_key] + assert [u'a', u'b', u'c'] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -272,9 +272,9 @@ def test_compound_primary_keys_protected(self): 
self.session.execute(create_statement) tablemeta = self.get_table_metadata() - self.assertEqual([u'Aa'], [c.name for c in tablemeta.partition_key]) - self.assertEqual([u'Bb'], [c.name for c in tablemeta.clustering_key]) - self.assertEqual([u'Aa', u'Bb', u'Cc'], sorted(tablemeta.columns.keys())) + assert [u'Aa'] == [c.name for c in tablemeta.partition_key] + assert [u'Bb'] == [c.name for c in tablemeta.clustering_key] + assert [u'Aa', u'Bb', u'Cc'] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -284,11 +284,9 @@ def test_compound_primary_keys_more_columns(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - self.assertEqual([u'a'], [c.name for c in tablemeta.partition_key]) - self.assertEqual([u'b', u'c'], [c.name for c in tablemeta.clustering_key]) - self.assertEqual( - [u'a', u'b', u'c', u'd', u'e', u'f'], - sorted(tablemeta.columns.keys())) + assert [u'a'] == [c.name for c in tablemeta.partition_key] + assert [u'b', u'c'] == [c.name for c in tablemeta.clustering_key] + assert [u'a', u'b', u'c', u'd', u'e', u'f'] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -297,9 +295,9 @@ def test_composite_primary_key(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - self.assertEqual([u'a', u'b'], [c.name for c in tablemeta.partition_key]) - self.assertEqual([], tablemeta.clustering_key) - self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys())) + assert [u'a', u'b'] == [c.name for c in tablemeta.partition_key] + assert [] == tablemeta.clustering_key + assert [u'a', u'b', u'c'] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -309,9 +307,9 @@ def test_composite_in_compound_primary_key(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - self.assertEqual([u'a', u'b'], [c.name for c in tablemeta.partition_key]) - 
self.assertEqual([u'c'], [c.name for c in tablemeta.clustering_key]) - self.assertEqual([u'a', u'b', u'c', u'd', u'e'], sorted(tablemeta.columns.keys())) + assert [u'a', u'b'] == [c.name for c in tablemeta.partition_key] + assert [u'c'] == [c.name for c in tablemeta.clustering_key] + assert [u'a', u'b', u'c', u'd', u'e'] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -321,9 +319,9 @@ def test_compound_primary_keys_compact(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - self.assertEqual([u'a'], [c.name for c in tablemeta.partition_key]) - self.assertEqual([u'b'], [c.name for c in tablemeta.clustering_key]) - self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys())) + assert [u'a'] == [c.name for c in tablemeta.partition_key] + assert [u'b'] == [c.name for c in tablemeta.clustering_key] + assert [u'a', u'b', u'c'] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -345,9 +343,9 @@ def test_cluster_column_ordering_reversed_metadata(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() b_column = tablemeta.columns['b'] - self.assertFalse(b_column.is_reversed) + assert not b_column.is_reversed c_column = tablemeta.columns['c'] - self.assertTrue(c_column.is_reversed) + assert c_column.is_reversed def test_compound_primary_keys_more_columns_compact(self): create_statement = self.make_create_statement(["a"], ["b", "c"], ["d"]) @@ -355,9 +353,9 @@ def test_compound_primary_keys_more_columns_compact(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - self.assertEqual([u'a'], [c.name for c in tablemeta.partition_key]) - self.assertEqual([u'b', u'c'], [c.name for c in tablemeta.clustering_key]) - self.assertEqual([u'a', u'b', u'c', u'd'], sorted(tablemeta.columns.keys())) + assert [u'a'] == [c.name for c in tablemeta.partition_key] + assert [u'b', u'c'] == [c.name for c 
in tablemeta.clustering_key] + assert [u'a', u'b', u'c', u'd'] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -366,9 +364,9 @@ def test_composite_primary_key_compact(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - self.assertEqual([u'a', u'b'], [c.name for c in tablemeta.partition_key]) - self.assertEqual([], tablemeta.clustering_key) - self.assertEqual([u'a', u'b', u'c'], sorted(tablemeta.columns.keys())) + assert [u'a', u'b'] == [c.name for c in tablemeta.partition_key] + assert [] == tablemeta.clustering_key + assert [u'a', u'b', u'c'] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -378,9 +376,9 @@ def test_composite_in_compound_primary_key_compact(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - self.assertEqual([u'a', u'b'], [c.name for c in tablemeta.partition_key]) - self.assertEqual([u'c'], [c.name for c in tablemeta.clustering_key]) - self.assertEqual([u'a', u'b', u'c', u'd'], sorted(tablemeta.columns.keys())) + assert [u'a', u'b'] == [c.name for c in tablemeta.partition_key] + assert [u'c'] == [c.name for c in tablemeta.clustering_key] + assert [u'a', u'b', u'c', u'd'] == sorted(tablemeta.columns.keys()) self.check_create_statement(tablemeta, create_statement) @@ -393,18 +391,18 @@ def test_cql_compatibility(self): self.session.execute(create_statement) tablemeta = self.get_table_metadata() - self.assertEqual([u'a'], [c.name for c in tablemeta.partition_key]) - self.assertEqual([], tablemeta.clustering_key) - self.assertEqual([u'a', u'b', u'c', u'd'], sorted(tablemeta.columns.keys())) + assert [u'a'] == [c.name for c in tablemeta.partition_key] + assert [] == tablemeta.clustering_key + assert [u'a', u'b', u'c', u'd'] == sorted(tablemeta.columns.keys()) - self.assertTrue(tablemeta.is_cql_compatible) + assert tablemeta.is_cql_compatible # It will be cql compatible after CASSANDRA-10857 # 
since compact storage is being dropped tablemeta.clustering_key = ["foo", "bar"] tablemeta.columns["foo"] = None tablemeta.columns["bar"] = None - self.assertTrue(tablemeta.is_cql_compatible) + assert tablemeta.is_cql_compatible def test_compound_primary_keys_ordering(self): create_statement = self.make_create_statement(["a"], ["b"], ["c"]) @@ -493,15 +491,15 @@ def test_indexes(self): statements = tablemeta.export_as_string().strip() statements = [s.strip() for s in statements.split(';')] statements = list(filter(bool, statements)) - self.assertEqual(3, len(statements)) - self.assertIn(d_index, statements) - self.assertIn(e_index, statements) + assert 3 == len(statements) + assert d_index in statements + assert e_index in statements # make sure indexes are included in KeyspaceMetadata.export_as_string() ksmeta = self.cluster.metadata.keyspaces[self.keyspace_name] statement = ksmeta.export_as_string() - self.assertIn('CREATE INDEX d_index', statement) - self.assertIn('CREATE INDEX e_index', statement) + assert 'CREATE INDEX d_index' in statement + assert 'CREATE INDEX e_index' in statement @greaterthancass21 @requires_collection_indexes @@ -514,7 +512,7 @@ def test_collection_indexes(self): % (self.keyspace_name, self.function_table_name)) tablemeta = self.get_table_metadata() - self.assertIn('(keys(b))', tablemeta.export_as_string()) + assert '(keys(b))' in tablemeta.export_as_string() self.session.execute("DROP INDEX %s.index1" % (self.keyspace_name,)) self.session.execute("CREATE INDEX index2 ON %s.%s (b)" @@ -522,7 +520,7 @@ def test_collection_indexes(self): tablemeta = self.get_table_metadata() target = ' (b)' if CASSANDRA_VERSION < Version("3.0") else 'values(b))' # explicit values in C* 3+ - self.assertIn(target, tablemeta.export_as_string()) + assert target in tablemeta.export_as_string() # test full indexes on frozen collections, if available if CASSANDRA_VERSION >= Version("2.1.3"): @@ -533,7 +531,7 @@ def test_collection_indexes(self): % 
(self.keyspace_name, self.function_table_name)) tablemeta = self.get_table_metadata() - self.assertIn('(full(b))', tablemeta.export_as_string()) + assert '(full(b))' in tablemeta.export_as_string() def test_compression_disabled(self): create_statement = self.make_create_statement(["a"], ["b"], ["c"]) @@ -543,7 +541,7 @@ def test_compression_disabled(self): expected = "compression = {'enabled': 'false'}" if SCYLLA_VERSION is not None or CASSANDRA_VERSION < Version("3.0"): expected = "compression = {}" - self.assertIn(expected, tablemeta.export_as_string()) + assert expected in tablemeta.export_as_string() def test_non_size_tiered_compaction(self): """ @@ -564,12 +562,12 @@ def test_non_size_tiered_compaction(self): table_meta = self.get_table_metadata() cql = table_meta.export_as_string() - self.assertIn("'tombstone_threshold': '0.3'", cql) - self.assertIn("LeveledCompactionStrategy", cql) + assert "'tombstone_threshold': '0.3'" in cql + assert "LeveledCompactionStrategy" in cql # formerly legacy options; reintroduced in 4.0 if CASSANDRA_VERSION < Version('4.0-a'): - self.assertNotIn("min_threshold", cql) - self.assertNotIn("max_threshold", cql) + assert "min_threshold" not in cql + assert "max_threshold" not in cql @requires_java_udf def test_refresh_schema_metadata(self): @@ -593,20 +591,20 @@ def test_refresh_schema_metadata(self): cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() - self.assertNotIn("new_keyspace", cluster2.metadata.keyspaces) + assert "new_keyspace" not in cluster2.metadata.keyspaces # Cluster metadata modification self.session.execute("CREATE KEYSPACE new_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}") - self.assertNotIn("new_keyspace", cluster2.metadata.keyspaces) + assert "new_keyspace" not in cluster2.metadata.keyspaces cluster2.refresh_schema_metadata() - self.assertIn("new_keyspace", cluster2.metadata.keyspaces) + assert "new_keyspace" in cluster2.metadata.keyspaces # Keyspace 
metadata modification self.session.execute("ALTER KEYSPACE {0} WITH durable_writes = false".format(self.keyspace_name)) - self.assertTrue(cluster2.metadata.keyspaces[self.keyspace_name].durable_writes) + assert cluster2.metadata.keyspaces[self.keyspace_name].durable_writes cluster2.refresh_schema_metadata() - self.assertFalse(cluster2.metadata.keyspaces[self.keyspace_name].durable_writes) + assert not cluster2.metadata.keyspaces[self.keyspace_name].durable_writes # Table metadata modification table_name = "test" @@ -614,16 +612,16 @@ def test_refresh_schema_metadata(self): cluster2.refresh_schema_metadata() self.session.execute("ALTER TABLE {0}.{1} ADD c double".format(self.keyspace_name, table_name)) - self.assertNotIn("c", cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns) + assert "c" not in cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns cluster2.refresh_schema_metadata() - self.assertIn("c", cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns) + assert "c" in cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns if PROTOCOL_VERSION >= 3: # UDT metadata modification self.session.execute("CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name)) - self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].user_types, {}) + assert cluster2.metadata.keyspaces[self.keyspace_name].user_types == {} cluster2.refresh_schema_metadata() - self.assertIn("user", cluster2.metadata.keyspaces[self.keyspace_name].user_types) + assert "user" in cluster2.metadata.keyspaces[self.keyspace_name].user_types if PROTOCOL_VERSION >= 4: # UDF metadata modification @@ -632,9 +630,9 @@ def test_refresh_schema_metadata(self): RETURNS int LANGUAGE java AS 'return key+val;';""".format(self.keyspace_name)) - self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].functions, {}) + assert cluster2.metadata.keyspaces[self.keyspace_name].functions == {} 
cluster2.refresh_schema_metadata() - self.assertIn("sum_int(int,int)", cluster2.metadata.keyspaces[self.keyspace_name].functions) + assert "sum_int(int,int)" in cluster2.metadata.keyspaces[self.keyspace_name].functions # UDA metadata modification self.session.execute("""CREATE AGGREGATE {0}.sum_agg(int) @@ -643,16 +641,16 @@ def test_refresh_schema_metadata(self): INITCOND 0""" .format(self.keyspace_name)) - self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].aggregates, {}) + assert cluster2.metadata.keyspaces[self.keyspace_name].aggregates == {} cluster2.refresh_schema_metadata() - self.assertIn("sum_agg(int)", cluster2.metadata.keyspaces[self.keyspace_name].aggregates) + assert "sum_agg(int)" in cluster2.metadata.keyspaces[self.keyspace_name].aggregates # Cluster metadata modification self.session.execute("DROP KEYSPACE new_keyspace") - self.assertIn("new_keyspace", cluster2.metadata.keyspaces) + assert "new_keyspace" in cluster2.metadata.keyspaces cluster2.refresh_schema_metadata() - self.assertNotIn("new_keyspace", cluster2.metadata.keyspaces) + assert "new_keyspace" not in cluster2.metadata.keyspaces cluster2.shutdown() @@ -676,11 +674,11 @@ def test_refresh_keyspace_metadata(self): cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() - self.assertTrue(cluster2.metadata.keyspaces[self.keyspace_name].durable_writes) + assert cluster2.metadata.keyspaces[self.keyspace_name].durable_writes self.session.execute("ALTER KEYSPACE {0} WITH durable_writes = false".format(self.keyspace_name)) - self.assertTrue(cluster2.metadata.keyspaces[self.keyspace_name].durable_writes) + assert cluster2.metadata.keyspaces[self.keyspace_name].durable_writes cluster2.refresh_keyspace_metadata(self.keyspace_name) - self.assertFalse(cluster2.metadata.keyspaces[self.keyspace_name].durable_writes) + assert not cluster2.metadata.keyspaces[self.keyspace_name].durable_writes cluster2.shutdown() @@ -707,12 +705,12 @@ def test_refresh_table_metadata(self): 
cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() - self.assertNotIn("c", cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns) + assert "c" not in cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns self.session.execute("ALTER TABLE {0}.{1} ADD c double".format(self.keyspace_name, table_name)) - self.assertNotIn("c", cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns) + assert "c" not in cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns cluster2.refresh_table_metadata(self.keyspace_name, table_name) - self.assertIn("c", cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns) + assert "c" in cluster2.metadata.keyspaces[self.keyspace_name].tables[table_name].columns cluster2.shutdown() @@ -742,38 +740,38 @@ def test_refresh_metadata_for_mv(self): cluster2.connect() try: - self.assertNotIn("mv1", cluster2.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) + assert "mv1" not in cluster2.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views self.session.execute("CREATE MATERIALIZED VIEW {0}.mv1 AS SELECT a, b FROM {0}.{1} " "WHERE a IS NOT NULL AND b IS NOT NULL PRIMARY KEY (a, b)" .format(self.keyspace_name, self.function_table_name)) - self.assertNotIn("mv1", cluster2.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) + assert "mv1" not in cluster2.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views cluster2.refresh_table_metadata(self.keyspace_name, "mv1") - self.assertIn("mv1", cluster2.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) + assert "mv1" in cluster2.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views finally: cluster2.shutdown() original_meta = self.cluster.metadata.keyspaces[self.keyspace_name].views['mv1'] - self.assertIs(original_meta, 
self.session.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views['mv1']) + assert original_meta is self.session.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views['mv1'] self.cluster.refresh_materialized_view_metadata(self.keyspace_name, 'mv1') current_meta = self.cluster.metadata.keyspaces[self.keyspace_name].views['mv1'] - self.assertIsNot(current_meta, original_meta) - self.assertIsNot(original_meta, self.session.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views['mv1']) - self.assertEqual(original_meta.as_cql_query(), current_meta.as_cql_query()) + assert current_meta is not original_meta + assert original_meta is not self.session.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views['mv1'] + assert original_meta.as_cql_query() == current_meta.as_cql_query() cluster3 = TestCluster(schema_event_refresh_window=-1) cluster3.connect() try: - self.assertNotIn("mv2", cluster3.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) + assert "mv2" not in cluster3.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views self.session.execute( "CREATE MATERIALIZED VIEW {0}.mv2 AS SELECT a, b FROM {0}.{1} " "WHERE a IS NOT NULL AND b IS NOT NULL PRIMARY KEY (a, b)".format( self.keyspace_name, self.function_table_name) ) - self.assertNotIn("mv2", cluster3.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) + assert "mv2" not in cluster3.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views cluster3.refresh_materialized_view_metadata(self.keyspace_name, 'mv2') - self.assertIn("mv2", cluster3.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) + assert "mv2" in cluster3.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views finally: cluster3.shutdown() @@ -800,12 +798,12 @@ def 
test_refresh_user_type_metadata(self): cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() - self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].user_types, {}) + assert cluster2.metadata.keyspaces[self.keyspace_name].user_types == {} self.session.execute("CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name)) - self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].user_types, {}) + assert cluster2.metadata.keyspaces[self.keyspace_name].user_types == {} cluster2.refresh_user_type_metadata(self.keyspace_name, "user") - self.assertIn("user", cluster2.metadata.keyspaces[self.keyspace_name].user_types) + assert "user" in cluster2.metadata.keyspaces[self.keyspace_name].user_types cluster2.shutdown() @@ -827,21 +825,21 @@ def test_refresh_user_type_metadata_proto_2(self): for protocol_version in (1, 2): cluster = TestCluster() session = cluster.connect() - self.assertEqual(cluster.metadata.keyspaces[self.keyspace_name].user_types, {}) + assert cluster.metadata.keyspaces[self.keyspace_name].user_types == {} session.execute("CREATE TYPE {0}.user (age int, name text)".format(self.keyspace_name)) - self.assertIn("user", cluster.metadata.keyspaces[self.keyspace_name].user_types) - self.assertIn("age", cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names) - self.assertIn("name", cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names) + assert "user" in cluster.metadata.keyspaces[self.keyspace_name].user_types + assert "age" in cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names + assert "name" in cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names session.execute("ALTER TYPE {0}.user ADD flag boolean".format(self.keyspace_name)) - self.assertIn("flag", cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names) + assert "flag" in 
cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names session.execute("ALTER TYPE {0}.user RENAME flag TO something".format(self.keyspace_name)) - self.assertIn("something", cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names) + assert "something" in cluster.metadata.keyspaces[self.keyspace_name].user_types["user"].field_names session.execute("DROP TYPE {0}.user".format(self.keyspace_name)) - self.assertEqual(cluster.metadata.keyspaces[self.keyspace_name].user_types, {}) + assert cluster.metadata.keyspaces[self.keyspace_name].user_types == {} cluster.shutdown() @requires_java_udf @@ -869,15 +867,15 @@ def test_refresh_user_function_metadata(self): cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() - self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].functions, {}) + assert cluster2.metadata.keyspaces[self.keyspace_name].functions == {} self.session.execute("""CREATE FUNCTION {0}.sum_int(key int, val int) RETURNS NULL ON NULL INPUT RETURNS int LANGUAGE java AS ' return key + val;';""".format(self.keyspace_name)) - self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].functions, {}) + assert cluster2.metadata.keyspaces[self.keyspace_name].functions == {} cluster2.refresh_user_function_metadata(self.keyspace_name, UserFunctionDescriptor("sum_int", ["int", "int"])) - self.assertIn("sum_int(int,int)", cluster2.metadata.keyspaces[self.keyspace_name].functions) + assert "sum_int(int,int)" in cluster2.metadata.keyspaces[self.keyspace_name].functions cluster2.shutdown() @@ -906,7 +904,7 @@ def test_refresh_user_aggregate_metadata(self): cluster2 = TestCluster(schema_event_refresh_window=-1) cluster2.connect() - self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].aggregates, {}) + assert cluster2.metadata.keyspaces[self.keyspace_name].aggregates == {} self.session.execute("""CREATE FUNCTION {0}.sum_int(key int, val int) RETURNS NULL ON NULL INPUT RETURNS int @@ 
-918,9 +916,9 @@ def test_refresh_user_aggregate_metadata(self): INITCOND 0""" .format(self.keyspace_name)) - self.assertEqual(cluster2.metadata.keyspaces[self.keyspace_name].aggregates, {}) + assert cluster2.metadata.keyspaces[self.keyspace_name].aggregates == {} cluster2.refresh_user_aggregate_metadata(self.keyspace_name, UserAggregateDescriptor("sum_agg", ["int"])) - self.assertIn("sum_agg(int)", cluster2.metadata.keyspaces[self.keyspace_name].aggregates) + assert "sum_agg(int)" in cluster2.metadata.keyspaces[self.keyspace_name].aggregates cluster2.shutdown() @@ -944,19 +942,19 @@ def test_multiple_indices(self): self.session.execute("CREATE INDEX index_2 ON {0}.{1}(keys(b))".format(self.keyspace_name, self.function_table_name)) indices = self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].indexes - self.assertEqual(len(indices), 2) + assert len(indices) == 2 index_1 = indices["index_1"] index_2 = indices['index_2'] - self.assertEqual(index_1.table_name, "test_multiple_indices") - self.assertEqual(index_1.name, "index_1") - self.assertEqual(index_1.kind, "COMPOSITES") - self.assertEqual(index_1.index_options["target"], "values(b)") - self.assertEqual(index_1.keyspace_name, "schemametadatatests") - self.assertEqual(index_2.table_name, "test_multiple_indices") - self.assertEqual(index_2.name, "index_2") - self.assertEqual(index_2.kind, "COMPOSITES") - self.assertEqual(index_2.index_options["target"], "keys(b)") - self.assertEqual(index_2.keyspace_name, "schemametadatatests") + assert index_1.table_name == "test_multiple_indices" + assert index_1.name == "index_1" + assert index_1.kind == "COMPOSITES" + assert index_1.index_options["target"] == "values(b)" + assert index_1.keyspace_name == "schemametadatatests" + assert index_2.table_name == "test_multiple_indices" + assert index_2.name == "index_2" + assert index_2.kind == "COMPOSITES" + assert index_2.index_options["target"] == "keys(b)" + assert index_2.keyspace_name == 
"schemametadatatests" @greaterthanorequalcass30 def test_table_extensions(self): @@ -990,17 +988,17 @@ def after_table_cql(cls, table_meta, ext_key, ext_blob): class Ext1(Ext0): name = t + '##' - self.assertIn(Ext0.name, _RegisteredExtensionType._extension_registry) - self.assertIn(Ext1.name, _RegisteredExtensionType._extension_registry) + assert Ext0.name in _RegisteredExtensionType._extension_registry + assert Ext1.name in _RegisteredExtensionType._extension_registry # There will bee the RLAC extension here. - self.assertEqual(len(_RegisteredExtensionType._extension_registry), 3) + assert len(_RegisteredExtensionType._extension_registry) == 3 self.cluster.refresh_table_metadata(ks, t) table_meta = ks_meta.tables[t] view_meta = table_meta.views[v] - self.assertEqual(table_meta.export_as_string(), original_table_cql) - self.assertEqual(view_meta.export_as_string(), original_view_cql) + assert table_meta.export_as_string() == original_table_cql + assert view_meta.export_as_string() == original_view_cql update_t = s.prepare('UPDATE system_schema.tables SET extensions=? WHERE keyspace_name=? AND table_name=?') # for blob type coercing update_v = s.prepare('UPDATE system_schema.views SET extensions=? WHERE keyspace_name=? 
AND view_name=?') @@ -1014,17 +1012,17 @@ class Ext1(Ext0): table_meta = ks_meta.tables[t] view_meta = table_meta.views[v] - self.assertIn(Ext0.name, table_meta.extensions) + assert Ext0.name in table_meta.extensions new_cql = table_meta.export_as_string() - self.assertNotEqual(new_cql, original_table_cql) - self.assertIn(Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]), new_cql) - self.assertNotIn(Ext1.name, new_cql) + assert new_cql != original_table_cql + assert Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]) in new_cql + assert Ext1.name not in new_cql - self.assertIn(Ext0.name, view_meta.extensions) + assert Ext0.name in view_meta.extensions new_cql = view_meta.export_as_string() - self.assertNotEqual(new_cql, original_view_cql) - self.assertIn(Ext0.after_table_cql(view_meta, Ext0.name, ext_map[Ext0.name]), new_cql) - self.assertNotIn(Ext1.name, new_cql) + assert new_cql != original_view_cql + assert Ext0.after_table_cql(view_meta, Ext0.name, ext_map[Ext0.name]) in new_cql + assert Ext1.name not in new_cql # extensions registered, one present # -------------------------------------- @@ -1037,19 +1035,19 @@ class Ext1(Ext0): table_meta = ks_meta.tables[t] view_meta = table_meta.views[v] - self.assertIn(Ext0.name, table_meta.extensions) - self.assertIn(Ext1.name, table_meta.extensions) + assert Ext0.name in table_meta.extensions + assert Ext1.name in table_meta.extensions new_cql = table_meta.export_as_string() - self.assertNotEqual(new_cql, original_table_cql) - self.assertIn(Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]), new_cql) - self.assertIn(Ext1.after_table_cql(table_meta, Ext1.name, ext_map[Ext1.name]), new_cql) + assert new_cql != original_table_cql + assert Ext0.after_table_cql(table_meta, Ext0.name, ext_map[Ext0.name]) in new_cql + assert Ext1.after_table_cql(table_meta, Ext1.name, ext_map[Ext1.name]) in new_cql - self.assertIn(Ext0.name, view_meta.extensions) - self.assertIn(Ext1.name, 
view_meta.extensions) + assert Ext0.name in view_meta.extensions + assert Ext1.name in view_meta.extensions new_cql = view_meta.export_as_string() - self.assertNotEqual(new_cql, original_view_cql) - self.assertIn(Ext0.after_table_cql(view_meta, Ext0.name, ext_map[Ext0.name]), new_cql) - self.assertIn(Ext1.after_table_cql(view_meta, Ext1.name, ext_map[Ext1.name]), new_cql) + assert new_cql != original_view_cql + assert Ext0.after_table_cql(view_meta, Ext0.name, ext_map[Ext0.name]) in new_cql + assert Ext1.after_table_cql(view_meta, Ext1.name, ext_map[Ext1.name]) in new_cql def test_metadata_pagination(self): self.cluster.refresh_schema_metadata() @@ -1059,7 +1057,7 @@ def test_metadata_pagination(self): self.cluster.schema_metadata_page_size = 5 self.cluster.refresh_schema_metadata() - self.assertEqual(len(self.cluster.metadata.keyspaces[self.keyspace_name].tables), 12) + assert len(self.cluster.metadata.keyspaces[self.keyspace_name].tables) == 12 def test_metadata_pagination_keyspaces(self): """ @@ -1084,7 +1082,7 @@ def test_metadata_pagination_keyspaces(self): after_ks_num = len(self.cluster.metadata.keyspaces) - self.assertEqual(before_ks_num, after_ks_num) + assert before_ks_num == after_ks_num class TestCodeCoverage(unittest.TestCase): @@ -1097,7 +1095,7 @@ def test_export_schema(self): cluster = TestCluster() cluster.connect() - self.assertIsInstance(cluster.metadata.export_schema_as_string(), str) + assert isinstance(cluster.metadata.export_schema_as_string(), str) cluster.shutdown() def test_export_keyspace_schema(self): @@ -1110,27 +1108,10 @@ def test_export_keyspace_schema(self): for keyspace in cluster.metadata.keyspaces: keyspace_metadata = cluster.metadata.keyspaces[keyspace] - self.assertIsInstance(keyspace_metadata.export_as_string(), str) - self.assertIsInstance(keyspace_metadata.as_cql_query(), str) + assert isinstance(keyspace_metadata.export_as_string(), str) + assert isinstance(keyspace_metadata.as_cql_query(), str) cluster.shutdown() - def 
assert_equal_diff(self, received, expected): - if received != expected: - diff_string = '\n'.join(difflib.unified_diff(expected.split('\n'), - received.split('\n'), - 'EXPECTED', 'RECEIVED', - lineterm='')) - self.fail(diff_string) - - def assert_startswith_diff(self, received, prefix): - if not received.startswith(prefix): - prefix_lines = prefix.split('\n') - diff_string = '\n'.join(difflib.unified_diff(prefix_lines, - received.split('\n')[:len(prefix_lines)], - 'EXPECTED', 'RECEIVED', - lineterm='')) - self.fail(diff_string) - @greaterthancass20 def test_export_keyspace_schema_udts(self): """ @@ -1195,7 +1176,7 @@ def test_export_keyspace_schema_udts(self): user text PRIMARY KEY, addresses map>""" - self.assert_startswith_diff(cluster.metadata.keyspaces['export_udts'].export_as_string(), expected_prefix) + assert_startswith_diff(cluster.metadata.keyspaces['export_udts'].export_as_string(), expected_prefix) table_meta = cluster.metadata.keyspaces['export_udts'].tables['users'] @@ -1203,7 +1184,7 @@ def test_export_keyspace_schema_udts(self): user text PRIMARY KEY, addresses map>""" - self.assert_startswith_diff(table_meta.export_as_string(), expected_prefix) + assert_startswith_diff(table_meta.export_as_string(), expected_prefix) cluster.shutdown() @@ -1244,15 +1225,15 @@ def test_case_sensitivity(self): ksmeta = cluster.metadata.keyspaces[ksname] schema = ksmeta.export_as_string() - self.assertIn('CREATE KEYSPACE "AnInterestingKeyspace"', schema) - self.assertIn('CREATE TABLE "AnInterestingKeyspace"."AnInterestingTable"', schema) - self.assertIn('"A" int', schema) - self.assertIn('"B" int', schema) - self.assertIn('"MyColumn" int', schema) - self.assertIn('PRIMARY KEY (k, "A")', schema) - self.assertIn('WITH CLUSTERING ORDER BY ("A" DESC)', schema) - self.assertIn('CREATE INDEX myindex ON "AnInterestingKeyspace"."AnInterestingTable" ("MyColumn")', schema) - self.assertIn('CREATE INDEX "AnotherIndex" ON "AnInterestingKeyspace"."AnInterestingTable" ("B")', schema) 
+ assert 'CREATE KEYSPACE "AnInterestingKeyspace"' in schema + assert 'CREATE TABLE "AnInterestingKeyspace"."AnInterestingTable"' in schema + assert '"A" int' in schema + assert '"B" int' in schema + assert '"MyColumn" int' in schema + assert 'PRIMARY KEY (k, "A")' in schema + assert 'WITH CLUSTERING ORDER BY ("A" DESC)' in schema + assert 'CREATE INDEX myindex ON "AnInterestingKeyspace"."AnInterestingTable" ("MyColumn")' in schema + assert 'CREATE INDEX "AnotherIndex" ON "AnInterestingKeyspace"."AnInterestingTable" ("B")' in schema cluster.shutdown() def test_already_exists_exceptions(self): @@ -1269,13 +1250,15 @@ def test_already_exists_exceptions(self): ddl = ''' CREATE KEYSPACE %s WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'}''' - self.assertRaises(AlreadyExists, session.execute, ddl % ksname) + with pytest.raises(AlreadyExists): + session.execute(ddl % ksname) ddl = ''' CREATE TABLE %s.%s ( k int PRIMARY KEY, v int )''' - self.assertRaises(AlreadyExists, session.execute, ddl % (ksname, cfname)) + with pytest.raises(AlreadyExists): + session.execute(ddl % (ksname, cfname)) cluster.shutdown() @local @@ -1288,14 +1271,14 @@ def test_replicas(self): raise unittest.SkipTest('the murmur3 extension is not available') cluster = TestCluster() - self.assertEqual(cluster.metadata.get_replicas('test3rf', 'key'), []) + assert cluster.metadata.get_replicas('test3rf', 'key') == [] cluster.connect('test3rf') - self.assertNotEqual(list(cluster.metadata.get_replicas('test3rf', b'key')), []) + assert list(cluster.metadata.get_replicas('test3rf', b'key')) != [] host = list(cluster.metadata.get_replicas('test3rf', b'key'))[0] - self.assertEqual(host.datacenter, 'dc1') - self.assertEqual(host.rack, 'r1') + assert host.datacenter == 'dc1' + assert host.rack == 'r1' cluster.shutdown() def test_token_map(self): @@ -1310,12 +1293,12 @@ def test_token_map(self): get_replicas = cluster.metadata.token_map.get_replicas for ksname in ('test1rf', 'test2rf', 
'test3rf'): - self.assertNotEqual(list(get_replicas(ksname, ring[0])), []) + assert list(get_replicas(ksname, ring[0])) != [] for i, token in enumerate(ring): - self.assertEqual(set(get_replicas('test3rf', token)), set(owners)) - self.assertEqual(set(get_replicas('test2rf', token)), set([owners[i], owners[(i + 1) % 3]])) - self.assertEqual(set(get_replicas('test1rf', token)), set([owners[i]])) + assert set(get_replicas('test3rf', token)) == set(owners) + assert set(get_replicas('test2rf', token)) == set([owners[i], owners[(i + 1) % 3]]) + assert set(get_replicas('test1rf', token)) == set([owners[i]]) cluster.shutdown() @@ -1330,8 +1313,8 @@ def test_token(self): cluster = TestCluster() cluster.connect() tmap = cluster.metadata.token_map - self.assertTrue(issubclass(tmap.token_class, Token)) - self.assertEqual(expected_node_count, len(tmap.ring)) + assert issubclass(tmap.token_class, Token) + assert expected_node_count == len(tmap.ring) cluster.shutdown() @@ -1363,8 +1346,7 @@ def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, for stmt in stmts: if "SELECT now() FROM system.local WHERE key='local'" in stmt: continue - if "USING TIMEOUT 2000ms" not in stmt: - self.fail(f"query `{stmt}` does not contain `USING TIMEOUT 2000ms`") + assert "USING TIMEOUT 2000ms" in stmt, f"query `{stmt}` does not contain `USING TIMEOUT 2000ms`" class KeyspaceAlterMetadata(unittest.TestCase): @@ -1398,13 +1380,13 @@ def test_keyspace_alter(self): self.session.execute('CREATE TABLE %s.d (d INT PRIMARY KEY)' % name) original_keyspace_meta = self.cluster.metadata.keyspaces[name] - self.assertEqual(original_keyspace_meta.durable_writes, True) - self.assertEqual(len(original_keyspace_meta.tables), 1) + assert original_keyspace_meta.durable_writes is True + assert len(original_keyspace_meta.tables) == 1 self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' % name) new_keyspace_meta = self.cluster.metadata.keyspaces[name] -
self.assertNotEqual(original_keyspace_meta, new_keyspace_meta) - self.assertEqual(new_keyspace_meta.durable_writes, False) + assert original_keyspace_meta != new_keyspace_meta + assert new_keyspace_meta.durable_writes == False class IndexMapTests(unittest.TestCase): @@ -1451,10 +1433,10 @@ def test_index_updates(self): ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] table_meta = ks_meta.tables[self.table_name] - self.assertNotIn('a_idx', ks_meta.indexes) - self.assertNotIn('b_idx', ks_meta.indexes) - self.assertNotIn('a_idx', table_meta.indexes) - self.assertNotIn('b_idx', table_meta.indexes) + assert 'a_idx' not in ks_meta.indexes + assert 'b_idx' not in ks_meta.indexes + assert 'a_idx' not in table_meta.indexes + assert 'b_idx' not in table_meta.indexes self.session.execute("CREATE INDEX a_idx ON %s (a)" % self.table_name) self.session.execute("ALTER TABLE %s ADD b int" % self.table_name) @@ -1462,10 +1444,10 @@ def test_index_updates(self): ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] table_meta = ks_meta.tables[self.table_name] - self.assertIsInstance(ks_meta.indexes['a_idx'], IndexMetadata) - self.assertIsInstance(ks_meta.indexes['b_idx'], IndexMetadata) - self.assertIsInstance(table_meta.indexes['a_idx'], IndexMetadata) - self.assertIsInstance(table_meta.indexes['b_idx'], IndexMetadata) + assert isinstance(ks_meta.indexes['a_idx'], IndexMetadata) + assert isinstance(ks_meta.indexes['b_idx'], IndexMetadata) + assert isinstance(table_meta.indexes['a_idx'], IndexMetadata) + assert isinstance(table_meta.indexes['b_idx'], IndexMetadata) # both indexes updated when index dropped self.session.execute("DROP INDEX a_idx") @@ -1475,17 +1457,17 @@ def test_index_updates(self): ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] table_meta = ks_meta.tables[self.table_name] - self.assertNotIn('a_idx', ks_meta.indexes) - self.assertIsInstance(ks_meta.indexes['b_idx'], IndexMetadata) - self.assertNotIn('a_idx', table_meta.indexes) - 
self.assertIsInstance(table_meta.indexes['b_idx'], IndexMetadata) + assert 'a_idx' not in ks_meta.indexes + assert isinstance(ks_meta.indexes['b_idx'], IndexMetadata) + assert 'a_idx' not in table_meta.indexes + assert isinstance(table_meta.indexes['b_idx'], IndexMetadata) # keyspace index updated when table dropped self.drop_basic_table() ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] - self.assertNotIn(self.table_name, ks_meta.tables) - self.assertNotIn('a_idx', ks_meta.indexes) - self.assertNotIn('b_idx', ks_meta.indexes) + assert self.table_name not in ks_meta.tables + assert 'a_idx' not in ks_meta.indexes + assert 'b_idx' not in ks_meta.indexes def test_index_follows_alter(self): self.create_basic_table() @@ -1494,15 +1476,15 @@ def test_index_follows_alter(self): self.session.execute("CREATE INDEX %s ON %s (a)" % (idx, self.table_name)) ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] table_meta = ks_meta.tables[self.table_name] - self.assertIsInstance(ks_meta.indexes[idx], IndexMetadata) - self.assertIsInstance(table_meta.indexes[idx], IndexMetadata) + assert isinstance(ks_meta.indexes[idx], IndexMetadata) + assert isinstance(table_meta.indexes[idx], IndexMetadata) self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' % self.keyspace_name) old_meta = ks_meta ks_meta = self.cluster.metadata.keyspaces[self.keyspace_name] - self.assertIsNot(ks_meta, old_meta) + assert ks_meta is not old_meta table_meta = ks_meta.tables[self.table_name] - self.assertIsInstance(ks_meta.indexes[idx], IndexMetadata) - self.assertIsInstance(table_meta.indexes[idx], IndexMetadata) + assert isinstance(ks_meta.indexes[idx], IndexMetadata) + assert isinstance(table_meta.indexes[idx], IndexMetadata) self.drop_basic_table() @requires_java_udf @@ -1551,18 +1533,18 @@ def __init__(self, test_case, meta_class, element_meta, **function_kwargs): def __enter__(self): tc = self.test_case expected_meta = self.meta_class(**self.function_kwargs) - 
tc.assertNotIn(expected_meta.signature, self.element_meta) + assert expected_meta.signature not in self.element_meta tc.session.execute(expected_meta.as_cql_query()) - tc.assertIn(expected_meta.signature, self.element_meta) + assert expected_meta.signature in self.element_meta generated_meta = self.element_meta[expected_meta.signature] - self.test_case.assertEqual(generated_meta.as_cql_query(), expected_meta.as_cql_query()) + assert generated_meta.as_cql_query() == expected_meta.as_cql_query() return self def __exit__(self, exc_type, exc_val, exc_tb): tc = self.test_case tc.session.execute("DROP %s %s.%s" % (self.meta_class.__name__, tc.keyspace_name, self.signature)) - tc.assertNotIn(self.signature, self.element_meta) + assert self.signature not in self.element_meta @property def signature(self): @@ -1616,7 +1598,7 @@ def test_functions_after_udt(self): @test_category function """ - self.assertNotIn(self.function_name, self.keyspace_function_meta) + assert self.function_name not in self.keyspace_function_meta udt_name = 'udtx' self.session.execute("CREATE TYPE %s (x int)" % udt_name) @@ -1626,8 +1608,8 @@ def test_functions_after_udt(self): keyspace_cql = self.cluster.metadata.keyspaces[self.keyspace_name].export_as_string() type_idx = keyspace_cql.rfind("CREATE TYPE") func_idx = keyspace_cql.find("CREATE FUNCTION") - self.assertNotIn(-1, (type_idx, func_idx), "TYPE or FUNCTION not found in keyspace_cql: " + keyspace_cql) - self.assertGreater(func_idx, type_idx) + assert -1 not in (type_idx, func_idx), "TYPE or FUNCTION not found in keyspace_cql: " + keyspace_cql + assert func_idx > type_idx def test_function_same_name_diff_types(self): """ @@ -1647,16 +1629,16 @@ def test_function_same_name_diff_types(self): with self.VerifiedFunction(self, **kwargs): # another function: same name, different type sig. 
- self.assertGreater(len(kwargs['argument_types']), 1) - self.assertGreater(len(kwargs['argument_names']), 1) + assert len(kwargs['argument_types']) > 1 + assert len(kwargs['argument_names']) > 1 kwargs['argument_types'] = kwargs['argument_types'][:1] kwargs['argument_names'] = kwargs['argument_names'][:1] # Ensure they are surfaced separately with self.VerifiedFunction(self, **kwargs): functions = [f for f in self.keyspace_function_meta.values() if f.name == self.function_name] - self.assertEqual(len(functions), 2) - self.assertNotEqual(functions[0].argument_types, functions[1].argument_types) + assert len(functions) == 2 + assert functions[0].argument_types != functions[1].argument_types def test_function_no_parameters(self): """ @@ -1677,7 +1659,7 @@ def test_function_no_parameters(self): with self.VerifiedFunction(self, **kwargs) as vf: fn_meta = self.keyspace_function_meta[vf.signature] - self.assertRegex(fn_meta.as_cql_query(), r'CREATE FUNCTION.*%s\(\) .*' % kwargs['name']) + self.assertRegex(fn_meta.as_cql_query(), r'CREATE FUNCTION.*%s\(\) .*' % kwargs['name']) def test_functions_follow_keyspace_alter(self): """ @@ -1701,8 +1683,8 @@ def test_functions_follow_keyspace_alter(self): # After keyspace alter ensure that we maintain function equality. 
try: new_keyspace_meta = self.cluster.metadata.keyspaces[self.keyspace_name] - self.assertNotEqual(original_keyspace_meta, new_keyspace_meta) - self.assertIs(original_keyspace_meta.functions, new_keyspace_meta.functions) + assert original_keyspace_meta != new_keyspace_meta + assert original_keyspace_meta.functions is new_keyspace_meta.functions finally: self.session.execute('ALTER KEYSPACE %s WITH durable_writes = true' % self.keyspace_name) @@ -1725,12 +1707,12 @@ def test_function_cql_called_on_null(self): kwargs['called_on_null_input'] = True with self.VerifiedFunction(self, **kwargs) as vf: fn_meta = self.keyspace_function_meta[vf.signature] - self.assertRegex(fn_meta.as_cql_query(), r'CREATE FUNCTION.*\) CALLED ON NULL INPUT RETURNS .*') + self.assertRegex(fn_meta.as_cql_query(), r'CREATE FUNCTION.*\) CALLED ON NULL INPUT RETURNS .*') kwargs['called_on_null_input'] = False with self.VerifiedFunction(self, **kwargs) as vf: fn_meta = self.keyspace_function_meta[vf.signature] - self.assertRegex(fn_meta.as_cql_query(), r'CREATE FUNCTION.*\) RETURNS NULL ON NULL INPUT RETURNS .*') + self.assertRegex(fn_meta.as_cql_query(), r'CREATE FUNCTION.*\) RETURNS NULL ON NULL INPUT RETURNS .*') @requires_java_udf @@ -1793,7 +1775,7 @@ def test_return_type_meta(self): """ with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('sum_int', 'int', init_cond='1')) as va: - self.assertEqual(self.keyspace_aggregate_meta[va.signature].return_type, 'int') + assert self.keyspace_aggregate_meta[va.signature].return_type == 'int' def test_init_cond(self): """ @@ -1822,16 +1804,15 @@ def test_init_cond(self): cql_init = encoder.cql_encode_all_types(init_cond) with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('sum_int', 'int', init_cond=cql_init)) as va: sum_res = s.execute("SELECT %s(v) AS sum FROM t" % va.function_kwargs['name']).one().sum - self.assertEqual(sum_res, int(init_cond) + sum(expected_values)) + assert sum_res == int(init_cond) + sum(expected_values) # list for 
init_cond in ([], ['1', '2']): cql_init = encoder.cql_encode_all_types(init_cond) with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('extend_list', 'list', init_cond=cql_init)) as va: list_res = s.execute("SELECT %s(v) AS list_res FROM t" % va.function_kwargs['name']).one().list_res - self.assertListEqual(list_res[:len(init_cond)], init_cond) - self.assertEqual(set(i for i in list_res[len(init_cond):]), - set(str(i) for i in expected_values)) + assert list_res[:len(init_cond)] == init_cond + assert set(i for i in list_res[len(init_cond):]) == set(str(i) for i in expected_values) # map expected_map_values = dict((i, i) for i in expected_values) @@ -1840,9 +1821,9 @@ def test_init_cond(self): cql_init = encoder.cql_encode_all_types(init_cond) with self.VerifiedAggregate(self, **self.make_aggregate_kwargs('update_map', 'map', init_cond=cql_init)) as va: map_res = s.execute("SELECT %s(v) AS map_res FROM t" % va.function_kwargs['name']).one().map_res - self.assertLessEqual(expected_map_values.items(), map_res.items()) + assert expected_map_values.items() <= map_res.items() init_not_updated = dict((k, init_cond[k]) for k in set(init_cond) - expected_key_set) - self.assertLessEqual(init_not_updated.items(), map_res.items()) + assert init_not_updated.items() <= map_res.items() c.shutdown() def test_aggregates_after_functions(self): @@ -1864,8 +1845,8 @@ def test_aggregates_after_functions(self): keyspace_cql = self.cluster.metadata.keyspaces[self.keyspace_name].export_as_string() func_idx = keyspace_cql.find("CREATE FUNCTION") aggregate_idx = keyspace_cql.rfind("CREATE AGGREGATE") - self.assertNotIn(-1, (aggregate_idx, func_idx), "AGGREGATE or FUNCTION not found in keyspace_cql: " + keyspace_cql) - self.assertGreater(aggregate_idx, func_idx) + assert -1 not in (aggregate_idx, func_idx), "AGGREGATE or FUNCTION not found in keyspace_cql: " + keyspace_cql + assert aggregate_idx > func_idx def test_same_name_diff_types(self): """ @@ -1886,8 +1867,8 @@ def 
test_same_name_diff_types(self): kwargs['argument_types'] = ['int', 'int'] with self.VerifiedAggregate(self, **kwargs): aggregates = [a for a in self.keyspace_aggregate_meta.values() if a.name == kwargs['name']] - self.assertEqual(len(aggregates), 2) - self.assertNotEqual(aggregates[0].argument_types, aggregates[1].argument_types) + assert len(aggregates) == 2 + assert aggregates[0].argument_types != aggregates[1].argument_types def test_aggregates_follow_keyspace_alter(self): """ @@ -1908,8 +1889,8 @@ def test_aggregates_follow_keyspace_alter(self): self.session.execute('ALTER KEYSPACE %s WITH durable_writes = false' % self.keyspace_name) try: new_keyspace_meta = self.cluster.metadata.keyspaces[self.keyspace_name] - self.assertNotEqual(original_keyspace_meta, new_keyspace_meta) - self.assertIs(original_keyspace_meta.aggregates, new_keyspace_meta.aggregates) + assert original_keyspace_meta != new_keyspace_meta + assert original_keyspace_meta.aggregates is new_keyspace_meta.aggregates finally: self.session.execute('ALTER KEYSPACE %s WITH durable_writes = true' % self.keyspace_name) @@ -1931,51 +1912,51 @@ def test_cql_optional_params(self): encoder = Encoder() # no initial condition, final func - self.assertIsNone(kwargs['initial_condition']) - self.assertIsNone(kwargs['final_func']) + assert kwargs['initial_condition'] is None + assert kwargs['final_func'] is None with self.VerifiedAggregate(self, **kwargs) as va: meta = self.keyspace_aggregate_meta[va.signature] - self.assertIsNone(meta.initial_condition) - self.assertIsNone(meta.final_func) + assert meta.initial_condition is None + assert meta.final_func is None cql = meta.as_cql_query() - self.assertEqual(cql.find('INITCOND'), -1) - self.assertEqual(cql.find('FINALFUNC'), -1) + assert cql.find('INITCOND') == -1 + assert cql.find('FINALFUNC') == -1 # initial condition, no final func kwargs['initial_condition'] = encoder.cql_encode_all_types(['init', 'cond']) with self.VerifiedAggregate(self, **kwargs) as va: meta 
= self.keyspace_aggregate_meta[va.signature] - self.assertEqual(meta.initial_condition, kwargs['initial_condition']) - self.assertIsNone(meta.final_func) + assert meta.initial_condition == kwargs['initial_condition'] + assert meta.final_func is None cql = meta.as_cql_query() search_string = "INITCOND %s" % kwargs['initial_condition'] - self.assertGreater(cql.find(search_string), 0, '"%s" search string not found in cql:\n%s' % (search_string, cql)) - self.assertEqual(cql.find('FINALFUNC'), -1) + assert cql.find(search_string) > 0, '"%s" search string not found in cql:\n%s' % (search_string, cql) + assert cql.find('FINALFUNC') == -1 # no initial condition, final func kwargs['initial_condition'] = None kwargs['final_func'] = 'List_As_String' with self.VerifiedAggregate(self, **kwargs) as va: meta = self.keyspace_aggregate_meta[va.signature] - self.assertIsNone(meta.initial_condition) - self.assertEqual(meta.final_func, kwargs['final_func']) + assert meta.initial_condition is None + assert meta.final_func == kwargs['final_func'] cql = meta.as_cql_query() - self.assertEqual(cql.find('INITCOND'), -1) + assert cql.find('INITCOND') == -1 search_string = 'FINALFUNC "%s"' % kwargs['final_func'] - self.assertGreater(cql.find(search_string), 0, '"%s" search string not found in cql:\n%s' % (search_string, cql)) + assert cql.find(search_string) > 0, '"%s" search string not found in cql:\n%s' % (search_string, cql) # both kwargs['initial_condition'] = encoder.cql_encode_all_types(['init', 'cond']) kwargs['final_func'] = 'List_As_String' with self.VerifiedAggregate(self, **kwargs) as va: meta = self.keyspace_aggregate_meta[va.signature] - self.assertEqual(meta.initial_condition, kwargs['initial_condition']) - self.assertEqual(meta.final_func, kwargs['final_func']) + assert meta.initial_condition == kwargs['initial_condition'] + assert meta.final_func == kwargs['final_func'] cql = meta.as_cql_query() init_cond_idx = cql.find("INITCOND %s" % kwargs['initial_condition']) 
final_func_idx = cql.find('FINALFUNC "%s"' % kwargs['final_func']) - self.assertNotIn(-1, (init_cond_idx, final_func_idx)) - self.assertGreater(init_cond_idx, final_func_idx) + assert -1 not in (init_cond_idx, final_func_idx) + assert init_cond_idx > final_func_idx class BadMetaTest(unittest.TestCase): @@ -2018,16 +1999,16 @@ def test_bad_keyspace(self): with patch.object(self.parser_class, '_build_keyspace_metadata_internal', side_effect=self.BadMetaException): self.cluster.refresh_keyspace_metadata(self.keyspace_name) m = self.cluster.metadata.keyspaces[self.keyspace_name] - self.assertIs(m._exc_info[0], self.BadMetaException) - self.assertIn("/*\nWarning:", m.export_as_string()) + assert m._exc_info[0] is self.BadMetaException + assert "/*\nWarning:" in m.export_as_string() def test_bad_table(self): self.session.execute('CREATE TABLE %s (k int PRIMARY KEY, v int)' % self.function_name) with patch.object(self.parser_class, '_build_column_metadata', side_effect=self.BadMetaException): self.cluster.refresh_table_metadata(self.keyspace_name, self.function_name) m = self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_name] - self.assertIs(m._exc_info[0], self.BadMetaException) - self.assertIn("/*\nWarning:", m.export_as_string()) + assert m._exc_info[0] is self.BadMetaException + assert "/*\nWarning:" in m.export_as_string() def test_bad_index(self): self.session.execute('CREATE TABLE %s (k int PRIMARY KEY, v int)' % self.function_name) @@ -2035,8 +2016,8 @@ def test_bad_index(self): with patch.object(self.parser_class, '_build_index_metadata', side_effect=self.BadMetaException): self.cluster.refresh_table_metadata(self.keyspace_name, self.function_name) m = self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_name] - self.assertIs(m._exc_info[0], self.BadMetaException) - self.assertIn("/*\nWarning:", m.export_as_string()) + assert m._exc_info[0] is self.BadMetaException + assert "/*\nWarning:" in m.export_as_string() 
@greaterthancass20 def test_bad_user_type(self): @@ -2044,8 +2025,8 @@ def test_bad_user_type(self): with patch.object(self.parser_class, '_build_user_type', side_effect=self.BadMetaException): self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh m = self.cluster.metadata.keyspaces[self.keyspace_name] - self.assertIs(m._exc_info[0], self.BadMetaException) - self.assertIn("/*\nWarning:", m.export_as_string()) + assert m._exc_info[0] is self.BadMetaException + assert "/*\nWarning:" in m.export_as_string() @greaterthancass21 @requires_java_udf @@ -2063,8 +2044,8 @@ def test_bad_user_function(self): with patch.object(self.parser_class, '_build_function', side_effect=self.BadMetaException): self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh m = self.cluster.metadata.keyspaces[self.keyspace_name] - self.assertIs(m._exc_info[0], self.BadMetaException) - self.assertIn("/*\nWarning:", m.export_as_string()) + assert m._exc_info[0] is self.BadMetaException + assert "/*\nWarning:" in m.export_as_string() @greaterthancass21 @requires_java_udf @@ -2082,8 +2063,8 @@ def test_bad_user_aggregate(self): with patch.object(self.parser_class, '_build_aggregate', side_effect=self.BadMetaException): self.cluster.refresh_schema_metadata() # presently do not capture these errors on udt direct refresh -- make sure it's contained during full refresh m = self.cluster.metadata.keyspaces[self.keyspace_name] - self.assertIs(m._exc_info[0], self.BadMetaException) - self.assertIn("/*\nWarning:", m.export_as_string()) + assert m._exc_info[0] is self.BadMetaException + assert "/*\nWarning:" in m.export_as_string() class DynamicCompositeTypeTest(BasicSharedKeyspaceUnitTestCase): @@ -2110,13 +2091,13 @@ def test_dct_alias(self): # Format can very slightly between versions, strip out whitespace for consistency 
sake table_text = dct_table.as_cql_query().replace(" ", "") dynamic_type_text = "c1'org.apache.cassandra.db.marshal.DynamicCompositeType(" - self.assertIn("c1'org.apache.cassandra.db.marshal.DynamicCompositeType(", table_text) + assert "c1'org.apache.cassandra.db.marshal.DynamicCompositeType(" in table_text # Types within in the composite can come out in random order, so grab the type definition and find each one type_definition_start = table_text.index("(", table_text.find(dynamic_type_text)) type_definition_end = table_text.index(")") type_definition_text = table_text[type_definition_start:type_definition_end] - self.assertIn("s=>org.apache.cassandra.db.marshal.UTF8Type", type_definition_text) - self.assertIn("i=>org.apache.cassandra.db.marshal.Int32Type", type_definition_text) + assert "s=>org.apache.cassandra.db.marshal.UTF8Type" in type_definition_text + assert "i=>org.apache.cassandra.db.marshal.Int32Type" in type_definition_text @greaterthanorequalcass30 @@ -2150,11 +2131,11 @@ def test_materialized_view_metadata_creation(self): @test_category metadata """ - self.assertIn("mv1", self.cluster.metadata.keyspaces[self.keyspace_name].views) - self.assertIn("mv1", self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) + assert "mv1" in self.cluster.metadata.keyspaces[self.keyspace_name].views + assert "mv1" in self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views - self.assertEqual(self.keyspace_name, self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].keyspace_name) - self.assertEqual(self.function_table_name, self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].base_table_name) + assert self.keyspace_name == self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].keyspace_name + assert self.function_table_name == 
self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].base_table_name def test_materialized_view_metadata_alter(self): """ @@ -2171,10 +2152,10 @@ def test_materialized_view_metadata_alter(self): @test_category metadata """ - self.assertIn("SizeTieredCompactionStrategy", self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].options["compaction"]["class"]) + assert "SizeTieredCompactionStrategy" in self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].options["compaction"]["class"] self.session.execute("ALTER MATERIALIZED VIEW {0}.mv1 WITH compaction = {{ 'class' : 'LeveledCompactionStrategy' }}".format(self.keyspace_name)) - self.assertIn("LeveledCompactionStrategy", self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].options["compaction"]["class"]) + assert "LeveledCompactionStrategy" in self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views["mv1"].options["compaction"]["class"] def test_materialized_view_metadata_drop(self): """ @@ -2194,10 +2175,10 @@ def test_materialized_view_metadata_drop(self): self.session.execute("DROP MATERIALIZED VIEW {0}.mv1".format(self.keyspace_name)) - self.assertNotIn("mv1", self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) - self.assertNotIn("mv1", self.cluster.metadata.keyspaces[self.keyspace_name].views) - self.assertDictEqual({}, self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views) - self.assertDictEqual({}, self.cluster.metadata.keyspaces[self.keyspace_name].views) + assert "mv1" not in self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views + assert "mv1" not in self.cluster.metadata.keyspaces[self.keyspace_name].views + assert self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].views == {} 
+ assert self.cluster.metadata.keyspaces[self.keyspace_name].views == {} self.session.execute( "CREATE MATERIALIZED VIEW {0}.mv1 AS SELECT pk, c FROM {0}.{1} " @@ -2244,63 +2225,63 @@ def test_create_view_metadata(self): score_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'] mv = self.cluster.metadata.keyspaces[self.keyspace_name].views['monthlyhigh'] - self.assertIsNotNone(score_table.views["monthlyhigh"]) - self.assertIsNotNone(len(score_table.views), 1) + assert score_table.views["monthlyhigh"] is not None + assert len(score_table.views) == 1 # Make sure user is a partition key, and not null - self.assertEqual(len(score_table.partition_key), 1) - self.assertIsNotNone(score_table.columns['user']) - self.assertTrue(score_table.columns['user'], score_table.partition_key[0]) + assert len(score_table.partition_key) == 1 + assert score_table.columns['user'] is not None + assert score_table.columns['user'], score_table.partition_key[0] # Validate clustering keys - self.assertEqual(len(score_table.clustering_key), 4) + assert len(score_table.clustering_key) == 4 - self.assertIsNotNone(score_table.columns['game']) - self.assertTrue(score_table.columns['game'], score_table.clustering_key[0]) + assert score_table.columns['game'] is not None + assert score_table.columns['game'], score_table.clustering_key[0] - self.assertIsNotNone(score_table.columns['year']) - self.assertTrue(score_table.columns['year'], score_table.clustering_key[1]) + assert score_table.columns['year'] is not None + assert score_table.columns['year'], score_table.clustering_key[1] - self.assertIsNotNone(score_table.columns['month']) - self.assertTrue(score_table.columns['month'], score_table.clustering_key[2]) + assert score_table.columns['month'] is not None + assert score_table.columns['month'], score_table.clustering_key[2] - 
self.assertIsNotNone(score_table.columns['day']) - self.assertTrue(score_table.columns['day'], score_table.clustering_key[3]) + assert score_table.columns['day'] is not None + assert score_table.columns['day'], score_table.clustering_key[3] - self.assertIsNotNone(score_table.columns['score']) + assert score_table.columns['score'] is not None # Validate basic mv information - self.assertEqual(mv.keyspace_name, self.keyspace_name) - self.assertEqual(mv.name, "monthlyhigh") - self.assertEqual(mv.base_table_name, "scores") - self.assertFalse(mv.include_all_columns) + assert mv.keyspace_name == self.keyspace_name + assert mv.name == "monthlyhigh" + assert mv.base_table_name == "scores" + assert not mv.include_all_columns # Validate that all columns are preset and correct mv_columns = list(mv.columns.values()) - self.assertEqual(len(mv_columns), 6) + assert len(mv_columns) == 6 game_column = mv_columns[0] - self.assertIsNotNone(game_column) - self.assertEqual(game_column.name, 'game') - self.assertEqual(game_column, mv.partition_key[0]) + assert game_column is not None + assert game_column.name == 'game' + assert game_column == mv.partition_key[0] year_column = mv_columns[1] - self.assertIsNotNone(year_column) - self.assertEqual(year_column.name, 'year') - self.assertEqual(year_column, mv.partition_key[1]) + assert year_column is not None + assert year_column.name == 'year' + assert year_column == mv.partition_key[1] month_column = mv_columns[2] - self.assertIsNotNone(month_column) - self.assertEqual(month_column.name, 'month') - self.assertEqual(month_column, mv.partition_key[2]) + assert month_column is not None + assert month_column.name == 'month' + assert month_column == mv.partition_key[2] def compare_columns(a, b, name): - self.assertEqual(a.name, name) - self.assertEqual(a.name, b.name) - self.assertEqual(a.table, b.table) - self.assertEqual(a.cql_type, b.cql_type) - self.assertEqual(a.is_static, b.is_static) - self.assertEqual(a.is_reversed, b.is_reversed) + 
assert a.name == name + assert a.name == b.name + assert a.table == b.table + assert a.cql_type == b.cql_type + assert a.is_static == b.is_static + assert a.is_reversed == b.is_reversed score_column = mv_columns[3] compare_columns(score_column, mv.clustering_key[0], 'score') @@ -2354,17 +2335,17 @@ def test_base_table_column_addition_mv(self): score_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'] - self.assertIsNotNone(score_table.views["monthlyhigh"]) - self.assertIsNotNone(score_table.views["alltimehigh"]) - self.assertEqual(len(self.cluster.metadata.keyspaces[self.keyspace_name].views), 2) + assert score_table.views["monthlyhigh"] is not None + assert score_table.views["alltimehigh"] is not None + assert len(self.cluster.metadata.keyspaces[self.keyspace_name].views) == 2 insert_fouls = """ALTER TABLE {0}.scores ADD fouls INT""".format((self.keyspace_name)) self.session.execute(insert_fouls) - self.assertEqual(len(self.cluster.metadata.keyspaces[self.keyspace_name].views), 2) + assert len(self.cluster.metadata.keyspaces[self.keyspace_name].views) == 2 score_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'] - self.assertIn("fouls", score_table.columns) + assert "fouls" in score_table.columns # This is a workaround for mv notifications being separate from base table schema responses. 
# This maybe fixed with future protocol changes @@ -2374,10 +2355,10 @@ def test_base_table_column_addition_mv(self): break time.sleep(.2) - self.assertIn("fouls", mv_alltime.columns) + assert "fouls" in mv_alltime.columns mv_alltime_fouls_comumn = self.cluster.metadata.keyspaces[self.keyspace_name].views["alltimehigh"].columns['fouls'] - self.assertEqual(mv_alltime_fouls_comumn.cql_type, 'int') + assert mv_alltime_fouls_comumn.cql_type == 'int' @lessthancass30 def test_base_table_type_alter_mv(self): @@ -2415,13 +2396,13 @@ def test_base_table_type_alter_mv(self): WITH CLUSTERING ORDER BY (score DESC, user ASC, day ASC)""".format(self.keyspace_name) self.session.execute(create_mv) - self.assertEqual(len(self.cluster.metadata.keyspaces[self.keyspace_name].views), 1) + assert len(self.cluster.metadata.keyspaces[self.keyspace_name].views) == 1 alter_scores = """ALTER TABLE {0}.scores ALTER score TYPE blob""".format((self.keyspace_name)) self.session.execute(alter_scores) - self.assertEqual(len(self.cluster.metadata.keyspaces[self.keyspace_name].views), 1) + assert len(self.cluster.metadata.keyspaces[self.keyspace_name].views) == 1 score_column = self.cluster.metadata.keyspaces[self.keyspace_name].tables['scores'].columns['score'] - self.assertEqual(score_column.cql_type, 'blob') + assert score_column.cql_type == 'blob' # until CASSANDRA-9920+CASSANDRA-10500 MV updates are only available later with an async event for i in range(10): @@ -2430,7 +2411,7 @@ def test_base_table_type_alter_mv(self): break time.sleep(.2) - self.assertEqual(score_mv_column.cql_type, 'blob') + assert score_mv_column.cql_type == 'blob' def test_metadata_with_quoted_identifiers(self): """ @@ -2466,48 +2447,48 @@ def test_metadata_with_quoted_identifiers(self): t1_table = self.cluster.metadata.keyspaces[self.keyspace_name].tables['t1'] mv = self.cluster.metadata.keyspaces[self.keyspace_name].views['mv1'] - self.assertIsNotNone(t1_table.views["mv1"]) - self.assertIsNotNone(len(t1_table.views), 1) 
+ assert t1_table.views["mv1"] is not None + assert len(t1_table.views) == 1 # Validate partition key, and not null - self.assertEqual(len(t1_table.partition_key), 1) - self.assertIsNotNone(t1_table.columns['theKey']) - self.assertTrue(t1_table.columns['theKey'], t1_table.partition_key[0]) + assert len(t1_table.partition_key) == 1 + assert t1_table.columns['theKey'] is not None + assert t1_table.columns['theKey'], t1_table.partition_key[0] # Validate clustering key column - self.assertEqual(len(t1_table.clustering_key), 1) - self.assertIsNotNone(t1_table.columns['the;Clustering']) - self.assertTrue(t1_table.columns['the;Clustering'], t1_table.clustering_key[0]) + assert len(t1_table.clustering_key) == 1 + assert t1_table.columns['the;Clustering'] is not None + assert t1_table.columns['the;Clustering'], t1_table.clustering_key[0] # Validate regular column - self.assertIsNotNone(t1_table.columns['the Value']) + assert t1_table.columns['the Value'] is not None # Validate basic mv information - self.assertEqual(mv.keyspace_name, self.keyspace_name) - self.assertEqual(mv.name, "mv1") - self.assertEqual(mv.base_table_name, "t1") - self.assertFalse(mv.include_all_columns) + assert mv.keyspace_name == self.keyspace_name + assert mv.name == "mv1" + assert mv.base_table_name == "t1" + assert not mv.include_all_columns # Validate that all columns are preset and correct mv_columns = list(mv.columns.values()) - self.assertEqual(len(mv_columns), 3) + assert len(mv_columns) == 3 theKey_column = mv_columns[0] - self.assertIsNotNone(theKey_column) - self.assertEqual(theKey_column.name, 'theKey') - self.assertEqual(theKey_column, mv.partition_key[0]) + assert theKey_column is not None + assert theKey_column.name == 'theKey' + assert theKey_column == mv.partition_key[0] cluster_column = mv_columns[1] - self.assertIsNotNone(cluster_column) - self.assertEqual(cluster_column.name, 'the;Clustering') - self.assertEqual(cluster_column.name, mv.clustering_key[0].name) - 
self.assertEqual(cluster_column.table, mv.clustering_key[0].table) - self.assertEqual(cluster_column.is_static, mv.clustering_key[0].is_static) - self.assertEqual(cluster_column.is_reversed, mv.clustering_key[0].is_reversed) + assert cluster_column is not None + assert cluster_column.name == 'the;Clustering' + assert cluster_column.name == mv.clustering_key[0].name + assert cluster_column.table == mv.clustering_key[0].table + assert cluster_column.is_static == mv.clustering_key[0].is_static + assert cluster_column.is_reversed == mv.clustering_key[0].is_reversed value_column = mv_columns[2] - self.assertIsNotNone(value_column) - self.assertEqual(value_column.name, 'the Value') + assert value_column is not None + assert value_column.name == 'the Value' class GroupPerHost(BasicSharedKeyspaceUnitTestCase): @@ -2548,14 +2529,14 @@ def test_group_keys_by_host(self): def _assert_group_keys_by_host(self, keys, table_name, stmt): keys_per_host = group_keys_by_replica(self.session, self.ks_name, table_name, keys) - self.assertNotIn(NO_VALID_REPLICA, keys_per_host) + assert NO_VALID_REPLICA not in keys_per_host prepared_stmt = self.session.prepare(stmt) for key in keys: routing_key = prepared_stmt.bind(key).routing_key hosts = self.cluster.metadata.get_replicas(self.ks_name, routing_key) - self.assertEqual(1, len(hosts)) # RF is 1 for this keyspace - self.assertIn(key, keys_per_host[hosts[0]]) + assert 1 == len(hosts) # RF is 1 for this keyspace + assert key in keys_per_host[hosts[0]] class VirtualKeypaceTest(BasicSharedKeyspaceUnitTestCase): @@ -2564,12 +2545,6 @@ class VirtualKeypaceTest(BasicSharedKeyspaceUnitTestCase): def test_existing_keyspaces_have_correct_virtual_tags(self): for name, ks in self.cluster.metadata.keyspaces.items(): if name in self.virtual_ks_names: - self.assertTrue( - ks.virtual, - 'incorrect .virtual value for {}'.format(name) - ) + assert ks.virtual, 'incorrect .virtual value for {}'.format(name) else: - self.assertFalse( - ks.virtual, - 'incorrect 
.virtual value for {}'.format(name) - ) + assert not ks.virtual, 'incorrect .virtual value for {}'.format(name) diff --git a/tests/integration/standard/test_metrics.py b/tests/integration/standard/test_metrics.py index 4b9ddb1351..8ccd278ee4 100644 --- a/tests/integration/standard/test_metrics.py +++ b/tests/integration/standard/test_metrics.py @@ -29,6 +29,7 @@ from tests.integration import BasicSharedKeyspaceUnitTestCaseRF3WM, BasicExistingKeyspaceUnitTestCase, local import pprint as pp +import pytest def setup_module(): @@ -70,14 +71,14 @@ def test_connection_error(self): # Ensure the nodes are actually down query = SimpleStatement("SELECT * FROM test", consistency_level=ConsistencyLevel.ALL) # both exceptions can happen depending on when the connection has been detected as defunct - with self.assertRaises((NoHostAvailable, ConnectionShutdown)): + with pytest.raises((NoHostAvailable, ConnectionShutdown)): self.session.execute(query) finally: get_cluster().start(wait_for_binary_proto=True, wait_other_notice=True) # Give some time for the cluster to come back up, for the next test time.sleep(5) - self.assertGreater(self.cluster.metrics.stats.connection_errors, 0) + assert self.cluster.metrics.stats.connection_errors > 0 def test_write_timeout(self): """ @@ -92,7 +93,7 @@ def test_write_timeout(self): # Assert read query = SimpleStatement("SELECT * FROM test WHERE k=1", consistency_level=ConsistencyLevel.ALL) results = execute_until_pass(self.session, query) - self.assertTrue(results) + assert results # Pause node so it shows as unreachable to coordinator get_node(1).pause() @@ -100,9 +101,9 @@ def test_write_timeout(self): try: # Test write query = SimpleStatement("INSERT INTO test (k, v) VALUES (2, 2)", consistency_level=ConsistencyLevel.ALL) - with self.assertRaises(WriteTimeout): + with pytest.raises(WriteTimeout): self.session.execute(query, timeout=None) - self.assertEqual(1, self.cluster.metrics.stats.write_timeouts) + assert 1 == 
self.cluster.metrics.stats.write_timeouts finally: get_node(1).resume() @@ -121,7 +122,7 @@ def test_read_timeout(self): # Assert read query = SimpleStatement("SELECT * FROM test WHERE k=1", consistency_level=ConsistencyLevel.ALL) results = execute_until_pass(self.session, query) - self.assertTrue(results) + assert results # Pause node so it shows as unreachable to coordinator get_node(1).pause() @@ -129,9 +130,9 @@ def test_read_timeout(self): try: # Test read query = SimpleStatement("SELECT * FROM test", consistency_level=ConsistencyLevel.ALL) - with self.assertRaises(ReadTimeout): + with pytest.raises(ReadTimeout): self.session.execute(query, timeout=None) - self.assertEqual(1, self.cluster.metrics.stats.read_timeouts) + assert 1 == self.cluster.metrics.stats.read_timeouts finally: get_node(1).resume() @@ -149,7 +150,7 @@ def test_unavailable(self): # Assert read query = SimpleStatement("SELECT * FROM test WHERE k=1", consistency_level=ConsistencyLevel.ALL) results = execute_until_pass(self.session, query) - self.assertTrue(results) + assert results # Stop node gracefully # Sometimes this commands continues with the other nodes having not noticed @@ -159,15 +160,15 @@ def test_unavailable(self): try: # Test write query = SimpleStatement("INSERT INTO test (k, v) VALUES (2, 2)", consistency_level=ConsistencyLevel.ALL) - with self.assertRaises(Unavailable): + with pytest.raises(Unavailable): self.session.execute(query) - self.assertEqual(self.cluster.metrics.stats.unavailables, 1) + assert self.cluster.metrics.stats.unavailables == 1 # Test write query = SimpleStatement("SELECT * FROM test", consistency_level=ConsistencyLevel.ALL) - with self.assertRaises(Unavailable): + with pytest.raises(Unavailable): self.session.execute(query, timeout=None) - self.assertEqual(self.cluster.metrics.stats.unavailables, 2) + assert self.cluster.metrics.stats.unavailables == 2 finally: get_node(1).start(wait_other_notice=True, wait_for_binary_proto=True) # Give some time for the 
cluster to come back up, for the next test @@ -206,7 +207,7 @@ def test_metrics_per_cluster(self): ) cluster2.connect(self.ks_name, wait_for_all_pools=True) - self.assertEqual(len(cluster2.metadata.all_hosts()), 3) + assert len(cluster2.metadata.all_hosts()) == 3 query = SimpleStatement("SELECT * FROM {0}.{0}".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) self.session.execute(query) @@ -217,7 +218,7 @@ def test_metrics_per_cluster(self): try: # Test write query = SimpleStatement("INSERT INTO {0}.{0} (k, v) VALUES (2, 2)".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) - with self.assertRaises(WriteTimeout): + with pytest.raises(WriteTimeout): self.session.execute(query, timeout=None) finally: get_node(1).resume() @@ -229,19 +230,19 @@ def test_metrics_per_cluster(self): stats_cluster2 = cluster2.metrics.get_stats() # Test direct access to stats - self.assertEqual(1, self.cluster.metrics.stats.write_timeouts) - self.assertEqual(0, cluster2.metrics.stats.write_timeouts) + assert 1 == self.cluster.metrics.stats.write_timeouts + assert 0 == cluster2.metrics.stats.write_timeouts # Test direct access to a child stats - self.assertNotEqual(0.0, self.cluster.metrics.request_timer['mean']) - self.assertEqual(0.0, cluster2.metrics.request_timer['mean']) + assert 0.0 != self.cluster.metrics.request_timer['mean'] + assert 0.0 == cluster2.metrics.request_timer['mean'] # Test access via metrics.get_stats() - self.assertNotEqual(0.0, stats_cluster1['request_timer']['mean']) - self.assertEqual(0.0, stats_cluster2['request_timer']['mean']) + assert 0.0 != stats_cluster1['request_timer']['mean'] + assert 0.0 == stats_cluster2['request_timer']['mean'] # Test access by stats_name - self.assertEqual(0.0, scales.getStats()['cluster2-metrics']['request_timer']['mean']) + assert 0.0 == scales.getStats()['cluster2-metrics']['request_timer']['mean'] cluster2.shutdown() @@ -269,7 +270,7 @@ def test_duplicate_metrics_per_cluster(self): # Ensure duplicate metric 
names are not allowed cluster2.metrics.set_stats_name("appcluster") cluster2.metrics.set_stats_name("appcluster") - with self.assertRaises(ValueError): + with pytest.raises(ValueError): cluster3.metrics.set_stats_name("appcluster") cluster3.metrics.set_stats_name("devops") @@ -285,12 +286,12 @@ def test_duplicate_metrics_per_cluster(self): query = SimpleStatement("SELECT * FROM {0}.{0}".format(self.ks_name), consistency_level=ConsistencyLevel.ALL) session3.execute(query) - self.assertEqual(cluster2.metrics.get_stats()['request_timer']['count'], 10) - self.assertEqual(cluster3.metrics.get_stats()['request_timer']['count'], 5) + assert cluster2.metrics.get_stats()['request_timer']['count'] == 10 + assert cluster3.metrics.get_stats()['request_timer']['count'] == 5 # Check scales to ensure they are appropriately named - self.assertTrue("appcluster" in scales._Stats.stats.keys()) - self.assertTrue("devops" in scales._Stats.stats.keys()) + assert "appcluster" in scales._Stats.stats.keys() + assert "devops" in scales._Stats.stats.keys() cluster2.shutdown() cluster3.shutdown() @@ -385,8 +386,8 @@ def test_metrics_per_cluster(self): except SyntaxException: continue - self.assertTrue(self.wait_for_count(ra, 10)) - self.assertTrue(self.wait_for_count(ra, 3, error=True)) + assert self.wait_for_count(ra, 10) + assert self.wait_for_count(ra, 3, error=True) ra.remove_ra(self.session) diff --git a/tests/integration/standard/test_policies.py b/tests/integration/standard/test_policies.py index faa21efb02..0c84fd06be 100644 --- a/tests/integration/standard/test_policies.py +++ b/tests/integration/standard/test_policies.py @@ -62,7 +62,7 @@ def test_predicate_changes(self): response = session.execute("SELECT * from system.local WHERE key='local'") queried_hosts.update(response.response_future.attempted_hosts) - self.assertEqual(queried_hosts, single_host) + assert queried_hosts == single_host external_event = False futures = session.update_created_pools() @@ -72,7 +72,7 @@ def 
test_predicate_changes(self): for _ in range(10): response = session.execute("SELECT * from system.local WHERE key='local'") queried_hosts.update(response.response_future.attempted_hosts) - self.assertEqual(queried_hosts, all_hosts) + assert queried_hosts == all_hosts class WhiteListRoundRobinPolicyTests(unittest.TestCase): @@ -89,7 +89,7 @@ def test_only_connects_to_subset(self): response = session.execute("SELECT * from system.local WHERE key='local'", execution_profile="white_list") queried_hosts.update(response.response_future.attempted_hosts) queried_hosts = set(host.address for host in queried_hosts) - self.assertEqual(queried_hosts, only_connect_hosts) + assert queried_hosts == only_connect_hosts class ExponentialRetryPolicyTests(unittest.TestCase): diff --git a/tests/integration/standard/test_prepared_statements.py b/tests/integration/standard/test_prepared_statements.py index d413a4dc95..68a704cd77 100644 --- a/tests/integration/standard/test_prepared_statements.py +++ b/tests/integration/standard/test_prepared_statements.py @@ -27,6 +27,7 @@ BasicSharedKeyspaceUnitTestCase) import logging +import pytest LOG = logging.getLogger(__name__) @@ -80,7 +81,7 @@ def test_basic(self): INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?) """) - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) bound = prepared.bind(('a', 'b', 'c')) self.session.execute(bound) @@ -89,11 +90,11 @@ def test_basic(self): """ SELECT * FROM cf0 WHERE a=? """) - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) bound = prepared.bind(('a')) results = self.session.execute(bound) - self.assertEqual(results, [('a', 'b', 'c')]) + assert results == [('a', 'b', 'c')] # test with new dict binding prepared = self.session.prepare( @@ -101,7 +102,7 @@ def test_basic(self): INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?) 
""") - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) bound = prepared.bind({ 'a': 'x', 'b': 'y', @@ -115,11 +116,11 @@ def test_basic(self): SELECT * FROM cf0 WHERE a=? """) - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) bound = prepared.bind({'a': 'x'}) results = self.session.execute(bound) - self.assertEqual(results, [('x', 'y', 'z')]) + assert results == [('x', 'y', 'z')] def test_missing_primary_key(self): """ @@ -133,12 +134,14 @@ def _run_missing_primary_key(self, session): statement_to_prepare = """INSERT INTO test3rf.test (v) VALUES (?)""" # logic needed work with changes in CASSANDRA-6237 if self.cass_version[0] >= (3, 0, 0): - self.assertRaises(InvalidRequest, session.prepare, statement_to_prepare) + with pytest.raises(InvalidRequest): + session.prepare(statement_to_prepare) else: prepared = session.prepare(statement_to_prepare) - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) bound = prepared.bind((1,)) - self.assertRaises(InvalidRequest, session.execute, bound) + with pytest.raises(InvalidRequest): + session.execute(bound) def test_missing_primary_key_dicts(self): """ @@ -152,12 +155,14 @@ def _run_missing_primary_key_dicts(self, session): statement_to_prepare = """ INSERT INTO test3rf.test (v) VALUES (?)""" # logic needed work with changes in CASSANDRA-6237 if self.cass_version[0] >= (3, 0, 0): - self.assertRaises(InvalidRequest, session.prepare, statement_to_prepare) + with pytest.raises(InvalidRequest): + session.prepare(statement_to_prepare) else: prepared = session.prepare(statement_to_prepare) - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) bound = prepared.bind({'v': 1}) - self.assertRaises(InvalidRequest, session.execute, bound) + with pytest.raises(InvalidRequest): + session.execute(bound) def test_too_many_bind_values(self): """ @@ 
-169,11 +174,13 @@ def _run_too_many_bind_values(self, session): statement_to_prepare = """ INSERT INTO test3rf.test (v) VALUES (?)""" # logic needed work with changes in CASSANDRA-6237 if self.cass_version[0] >= (2, 2, 8): - self.assertRaises(InvalidRequest, session.prepare, statement_to_prepare) + with pytest.raises(InvalidRequest): + session.prepare(statement_to_prepare) else: prepared = session.prepare(statement_to_prepare) - self.assertIsInstance(prepared, PreparedStatement) - self.assertRaises(ValueError, prepared.bind, (1, 2)) + assert isinstance(prepared, PreparedStatement) + with pytest.raises(ValueError): + prepared.bind((1, 2)) def test_imprecise_bind_values_dicts(self): """ @@ -186,7 +193,7 @@ def test_imprecise_bind_values_dicts(self): INSERT INTO test3rf.test (k, v) VALUES (?, ?) """) - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) # too many values is ok - others are ignored prepared.bind({'k': 1, 'v': 2, 'v2': 3}) @@ -194,18 +201,21 @@ def test_imprecise_bind_values_dicts(self): # right number, but one does not belong if PROTOCOL_VERSION < 4: # pre v4, the driver bails with key error when 'v' is found missing - self.assertRaises(KeyError, prepared.bind, {'k': 1, 'v2': 3}) + with pytest.raises(KeyError): + prepared.bind({'k': 1, 'v2': 3}) else: # post v4, the driver uses UNSET_VALUE for 'v' and 'v2' is ignored prepared.bind({'k': 1, 'v2': 3}) # also catch too few variables with dicts - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) if PROTOCOL_VERSION < 4: - self.assertRaises(KeyError, prepared.bind, {}) + with pytest.raises(KeyError): + prepared.bind({}) else: # post v4, the driver attempts to use UNSET_VALUE for unspecified keys - self.assertRaises(ValueError, prepared.bind, {}) + with pytest.raises(ValueError): + prepared.bind({}) def test_none_values(self): """ @@ -217,7 +227,7 @@ def test_none_values(self): INSERT INTO test3rf.test (k, v) 
VALUES (?, ?) """) - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) bound = prepared.bind((1, None)) self.session.execute(bound) @@ -225,11 +235,11 @@ def test_none_values(self): """ SELECT * FROM test3rf.test WHERE k=? """) - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) bound = prepared.bind((1,)) results = self.session.execute(bound) - self.assertEqual(results.one().v, None) + assert results.one().v == None def test_unset_values(self): """ @@ -272,9 +282,10 @@ def test_unset_values(self): for params, expected in bind_expected: self.session.execute(insert, params) results = self.session.execute(select, (0,)) - self.assertEqual(results.one(), expected) + assert results.one() == expected - self.assertRaises(ValueError, self.session.execute, select, (UNSET_VALUE, 0, 0)) + with pytest.raises(ValueError): + self.session.execute(select, (UNSET_VALUE, 0, 0)) def test_no_meta(self): @@ -283,7 +294,7 @@ def test_no_meta(self): INSERT INTO test3rf.test (k, v) VALUES (0, 0) """) - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) bound = prepared.bind(None) bound.consistency_level = ConsistencyLevel.ALL self.session.execute(bound) @@ -292,12 +303,12 @@ def test_no_meta(self): """ SELECT * FROM test3rf.test WHERE k=0 """) - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) bound = prepared.bind(None) bound.consistency_level = ConsistencyLevel.ALL results = self.session.execute(bound) - self.assertEqual(results.one().v, 0) + assert results.one().v == 0 def test_none_values_dicts(self): """ @@ -310,7 +321,7 @@ def test_none_values_dicts(self): INSERT INTO test3rf.test (k, v) VALUES (?, ?) 
""") - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) bound = prepared.bind({'k': 1, 'v': None}) self.session.execute(bound) @@ -318,11 +329,11 @@ def test_none_values_dicts(self): """ SELECT * FROM test3rf.test WHERE k=? """) - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) bound = prepared.bind({'k': 1}) results = self.session.execute(bound) - self.assertEqual(results.one().v, None) + assert results.one().v == None def test_async_binding(self): """ @@ -334,7 +345,7 @@ def test_async_binding(self): INSERT INTO test3rf.test (k, v) VALUES (?, ?) """) - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) future = self.session.execute_async(prepared, (873, None)) future.result() @@ -342,11 +353,11 @@ def test_async_binding(self): """ SELECT * FROM test3rf.test WHERE k=? """) - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) future = self.session.execute_async(prepared, (873,)) results = future.result() - self.assertEqual(results.one().v, None) + assert results.one().v == None def test_async_binding_dicts(self): """ @@ -357,7 +368,7 @@ def test_async_binding_dicts(self): INSERT INTO test3rf.test (k, v) VALUES (?, ?) """) - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) future = self.session.execute_async(prepared, {'k': 873, 'v': None}) future.result() @@ -365,11 +376,11 @@ def test_async_binding_dicts(self): """ SELECT * FROM test3rf.test WHERE k=? 
""") - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) future = self.session.execute_async(prepared, {'k': 873}) results = future.result() - self.assertEqual(results.one().v, None) + assert results.one().v == None def test_raise_error_on_prepared_statement_execution_dropped_table(self): """ @@ -392,7 +403,7 @@ def test_raise_error_on_prepared_statement_execution_dropped_table(self): prepared = self.session.prepare("SELECT * FROM test3rf.error_test WHERE k=?") self.session.execute("DROP TABLE test3rf.error_test") - with self.assertRaises(InvalidRequest): + with pytest.raises(InvalidRequest): self.session.execute(prepared, [0]) @unittest.skipIf((CASSANDRA_VERSION >= Version('3.11.12') and CASSANDRA_VERSION < Version('4.0')) or \ @@ -410,9 +421,9 @@ def test_fail_if_different_query_id_on_reprepare(self): self.session.execute("DROP TABLE {}.foo".format(keyspace)) self.session.execute("CREATE TABLE {}.foo(k int PRIMARY KEY)".format(keyspace)) self.session.execute("USE {}".format(keyspace)) - with self.assertRaises(DriverException) as e: + with pytest.raises(DriverException) as e: self.session.execute(prepared, [0]) - self.assertIn("ID mismatch", str(e.exception)) + assert "ID mismatch" in str(e.value) @greaterthanorequalcass40 @@ -441,19 +452,19 @@ def test_invalidated_result_metadata(self): """ wildcard_prepared = self.session.prepare("SELECT * FROM {}".format(self.table_name)) original_result_metadata = wildcard_prepared.result_metadata - self.assertEqual(len(original_result_metadata), 3) + assert len(original_result_metadata) == 3 r = self.session.execute(wildcard_prepared) - self.assertEqual(r[0], (1, 1, 1)) + assert r[0] == (1, 1, 1) self.session.execute("ALTER TABLE {} DROP d".format(self.table_name)) # Get a bunch of requests in the pipeline with varying states of result_meta, reprepare, resolved futures = set(self.session.execute_async(wildcard_prepared.bind(None)) for _ in range(200)) for f in futures: - 
self.assertEqual(f.result()[0], (1, 1)) + assert f.result()[0] == (1, 1) - self.assertIsNot(wildcard_prepared.result_metadata, original_result_metadata) + assert wildcard_prepared.result_metadata is not original_result_metadata def test_prepared_id_is_update(self): """ @@ -468,7 +479,7 @@ def test_prepared_id_is_update(self): """ prepared_statement = self.session.prepare("SELECT * from {} WHERE a = ?".format(self.table_name)) id_before = prepared_statement.result_metadata_id - self.assertEqual(len(prepared_statement.result_metadata), 3) + assert len(prepared_statement.result_metadata) == 3 self.session.execute("ALTER TABLE {} ADD c int".format(self.table_name)) bound_statement = prepared_statement.bind((1, )) @@ -476,8 +487,8 @@ def test_prepared_id_is_update(self): id_after = prepared_statement.result_metadata_id - self.assertNotEqual(id_before, id_after) - self.assertEqual(len(prepared_statement.result_metadata), 4) + assert id_before != id_after + assert len(prepared_statement.result_metadata) == 4 def test_prepared_id_is_updated_across_pages(self): """ @@ -491,12 +502,12 @@ def test_prepared_id_is_updated_across_pages(self): """ prepared_statement = self.session.prepare("SELECT * from {}".format(self.table_name)) id_before = prepared_statement.result_metadata_id - self.assertEqual(len(prepared_statement.result_metadata), 3) + assert len(prepared_statement.result_metadata) == 3 prepared_statement.fetch_size = 2 result = self.session.execute(prepared_statement.bind((None))) - self.assertTrue(result.has_more_pages) + assert result.has_more_pages self.session.execute("ALTER TABLE {} ADD c int".format(self.table_name)) @@ -505,9 +516,9 @@ def test_prepared_id_is_updated_across_pages(self): id_after = prepared_statement.result_metadata_id - self.assertEqual(result_set, expected_result_set) - self.assertNotEqual(id_before, id_after) - self.assertEqual(len(prepared_statement.result_metadata), 4) + assert result_set == expected_result_set + assert id_before != id_after 
+ assert len(prepared_statement.result_metadata) == 4 def test_prepare_id_is_updated_across_session(self): """ @@ -524,7 +535,7 @@ def test_prepare_id_is_updated_across_session(self): stm = "SELECT * from {} WHERE a = ?".format(self.table_name) one_prepared_stm = one_session.prepare(stm) - self.assertEqual(len(one_prepared_stm.result_metadata), 3) + assert len(one_prepared_stm.result_metadata) == 3 one_id_before = one_prepared_stm.result_metadata_id @@ -532,8 +543,8 @@ def test_prepare_id_is_updated_across_session(self): one_session.execute(one_prepared_stm, (1, )) one_id_after = one_prepared_stm.result_metadata_id - self.assertNotEqual(one_id_before, one_id_after) - self.assertEqual(len(one_prepared_stm.result_metadata), 4) + assert one_id_before != one_id_after + assert len(one_prepared_stm.result_metadata) == 4 def test_not_reprepare_invalid_statements(self): """ @@ -546,7 +557,7 @@ def test_not_reprepare_invalid_statements(self): prepared_statement = self.session.prepare( "SELECT a, b, d FROM {} WHERE a = ?".format(self.table_name)) self.session.execute("ALTER TABLE {} DROP d".format(self.table_name)) - with self.assertRaises(InvalidRequest): + with pytest.raises(InvalidRequest): self.session.execute(prepared_statement.bind((1, ))) def test_id_is_not_updated_conditional_v4(self): @@ -584,12 +595,9 @@ def _test_updated_conditional(self, session, value): LOG.debug('initial result_metadata_id: {}'.format(first_id)) def check_result_and_metadata(expected): - self.assertEqual( - session.execute(prepared_statement, (value, value, value)).one(), - expected - ) - self.assertEqual(prepared_statement.result_metadata_id, first_id) - self.assertIsNone(prepared_statement.result_metadata) + assert session.execute(prepared_statement, (value, value, value)).one() == expected + assert prepared_statement.result_metadata_id == first_id + assert prepared_statement.result_metadata is None # Successful conditional update check_result_and_metadata((True,)) diff --git 
a/tests/integration/standard/test_query.py b/tests/integration/standard/test_query.py index a4d1b083bf..a16a34233a 100644 --- a/tests/integration/standard/test_query.py +++ b/tests/integration/standard/test_query.py @@ -29,6 +29,7 @@ USE_CASS_EXTERNAL, greaterthanorequalcass40, TestCluster, xfail_scylla from tests import notwindows from tests.integration import greaterthanorequalcass30, get_node +from tests.util import assertListEqual import time import random @@ -63,12 +64,12 @@ def test_query(self): INSERT INTO test3rf.test (k, v) VALUES (?, ?) """.format(self.keyspace_name)) - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) bound = prepared.bind((1, None)) - self.assertIsInstance(bound, BoundStatement) - self.assertEqual(2, len(bound.values)) + assert isinstance(bound, BoundStatement) + assert 2 == len(bound.values) self.session.execute(bound) - self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01') + assert bound.routing_key == b'\x00\x00\x00\x01' def test_trace_prints_okay(self): """ @@ -81,7 +82,7 @@ def test_trace_prints_okay(self): # Ensure this does not throw an exception trace = rs.get_query_trace() - self.assertTrue(trace.events) + assert trace.events str(trace) for event in trace.events: str(event) @@ -98,9 +99,9 @@ def test_row_error_message(self): self.session.execute("CREATE TABLE {0}.{1} (k int PRIMARY KEY, v timestamp)".format(self.keyspace_name,self.function_table_name)) ss = SimpleStatement("INSERT INTO {0}.{1} (k, v) VALUES (1, 1000000000000000)".format(self.keyspace_name, self.function_table_name)) self.session.execute(ss) - with self.assertRaises(DriverException) as context: + with pytest.raises(DriverException) as context: self.session.execute("SELECT * FROM {0}.{1}".format(self.keyspace_name, self.function_table_name)) - self.assertIn("Failed decoding result column", str(context.exception)) + assert "Failed decoding result column" in str(context.value) def 
test_trace_id_to_resultset(self): @@ -109,14 +110,14 @@ def test_trace_id_to_resultset(self): # future should have the current trace rs = future.result() future_trace = future.get_query_trace() - self.assertIsNotNone(future_trace) + assert future_trace is not None rs_trace = rs.get_query_trace() - self.assertEqual(rs_trace, future_trace) - self.assertTrue(rs_trace.events) - self.assertEqual(len(rs_trace.events), len(future_trace.events)) + assert rs_trace == future_trace + assert rs_trace.events + assert len(rs_trace.events) == len(future_trace.events) - self.assertListEqual([rs_trace], rs.get_all_query_traces()) + assertListEqual([rs_trace], rs.get_all_query_traces()) def test_trace_ignores_row_factory(self): with TestCluster( @@ -129,7 +130,7 @@ def test_trace_ignores_row_factory(self): # Ensure this does not throw an exception trace = rs.get_query_trace() - self.assertTrue(trace.events) + assert trace.events str(trace) for event in trace.events: str(event) @@ -170,8 +171,8 @@ def test_client_ip_in_trace(self): pat = re.compile(r'127.0.0.\d{1,3}') # Ensure that ip is set - self.assertIsNotNone(client_ip, "Client IP was not set in trace with C* >= 2.2") - self.assertTrue(pat.match(client_ip), "Client IP from trace did not match the expected value") + assert client_ip is not None, "Client IP was not set in trace with C* >= 2.2" + assert pat.match(client_ip), "Client IP from trace did not match the expected value" def test_trace_cl(self): """ @@ -186,18 +187,18 @@ def test_trace_cl(self): statement = SimpleStatement(query) response_future = self.session.execute_async(statement, trace=True) response_future.result() - with self.assertRaises(Unavailable): + with pytest.raises(Unavailable): response_future.get_query_trace(query_cl=ConsistencyLevel.THREE) # Try again with a smattering of other CL's - self.assertIsNotNone(response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.TWO).trace_id) + assert response_future.get_query_trace(max_wait=2.0, 
query_cl=ConsistencyLevel.TWO).trace_id is not None response_future = self.session.execute_async(statement, trace=True) response_future.result() - self.assertIsNotNone(response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.ONE).trace_id) + assert response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.ONE).trace_id is not None response_future = self.session.execute_async(statement, trace=True) response_future.result() - with self.assertRaises(InvalidRequest): - self.assertIsNotNone(response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.ANY).trace_id) - self.assertIsNotNone(response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.QUORUM).trace_id) + with pytest.raises(InvalidRequest): + assert response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.ANY).trace_id is not None + assert response_future.get_query_trace(max_wait=2.0, query_cl=ConsistencyLevel.QUORUM).trace_id is not None @notwindows def test_incomplete_query_trace(self): @@ -222,27 +223,28 @@ def test_incomplete_query_trace(self): response_future = self.session.execute_async("SELECT i FROM {0} WHERE k=0".format(self.keyspace_table_name), trace=True) response_future.result() - self.assertEqual(len(response_future._query_traces), 1) + assert len(response_future._query_traces) == 1 trace = response_future._query_traces[0] - self.assertTrue(self._wait_for_trace_to_populate(trace.trace_id)) + assert self._wait_for_trace_to_populate(trace.trace_id) # Delete trace duration from the session (this is what the driver polls for "complete") delete_statement = SimpleStatement("DELETE duration FROM system_traces.sessions WHERE session_id = {0}".format(trace.trace_id), consistency_level=ConsistencyLevel.ALL) self.session.execute(delete_statement) - self.assertTrue(self._wait_for_trace_to_delete(trace.trace_id)) + assert self._wait_for_trace_to_delete(trace.trace_id) # should raise because duration is not set - 
self.assertRaises(TraceUnavailable, trace.populate, max_wait=0.2, wait_for_complete=True) - self.assertFalse(trace.events) + with pytest.raises(TraceUnavailable): + trace.populate(max_wait=0.2, wait_for_complete=True) + assert not trace.events # should get the events with wait False trace.populate(wait_for_complete=False) - self.assertIsNone(trace.duration) - self.assertIsNotNone(trace.trace_id) - self.assertIsNotNone(trace.request_type) - self.assertIsNotNone(trace.parameters) - self.assertTrue(trace.events) # non-zero list len - self.assertIsNotNone(trace.started_at) + assert trace.duration is None + assert trace.trace_id is not None + assert trace.request_type is not None + assert trace.parameters is not None + assert trace.events # non-zero list len + assert trace.started_at is not None def _wait_for_trace_to_populate(self, trace_id): count = 0 @@ -283,17 +285,17 @@ def test_query_by_id(self): self.session.execute("insert into "+self.keyspace_name+"."+self.function_table_name+" (id, m) VALUES ( 1, {1: 'one', 2: 'two', 3:'three'})") results1 = self.session.execute("select id, m from {0}.{1}".format(self.keyspace_name, self.function_table_name)) - self.assertIsNotNone(results1.column_types) - self.assertEqual(results1.column_types[0].typename, 'int') - self.assertEqual(results1.column_types[1].typename, 'map') - self.assertEqual(results1.column_types[0].cassname, 'Int32Type') - self.assertEqual(results1.column_types[1].cassname, 'MapType') - self.assertEqual(len(results1.column_types[0].subtypes), 0) - self.assertEqual(len(results1.column_types[1].subtypes), 2) - self.assertEqual(results1.column_types[1].subtypes[0].typename, "int") - self.assertEqual(results1.column_types[1].subtypes[1].typename, "varchar") - self.assertEqual(results1.column_types[1].subtypes[0].cassname, "Int32Type") - self.assertEqual(results1.column_types[1].subtypes[1].cassname, "VarcharType") + assert results1.column_types is not None + assert results1.column_types[0].typename == 'int' + 
assert results1.column_types[1].typename == 'map' + assert results1.column_types[0].cassname == 'Int32Type' + assert results1.column_types[1].cassname == 'MapType' + assert len(results1.column_types[0].subtypes) == 0 + assert len(results1.column_types[1].subtypes) == 2 + assert results1.column_types[1].subtypes[0].typename == "int" + assert results1.column_types[1].subtypes[1].typename == "varchar" + assert results1.column_types[1].subtypes[0].cassname == "Int32Type" + assert results1.column_types[1].subtypes[1].cassname == "VarcharType" def test_column_names(self): """ @@ -319,9 +321,9 @@ def test_column_names(self): self.session.execute(create_table) result_set = self.session.execute("SELECT * FROM {0}.{1}".format(self.keyspace_name, self.function_table_name)) - self.assertIsNotNone(result_set.column_types) + assert result_set.column_types is not None - self.assertEqual(result_set.column_names, [u'user', u'game', u'year', u'month', u'day', u'score']) + assert result_set.column_names == [u'user', u'game', u'year', u'month', u'day', u'score'] @greaterthanorequalcass30 def test_basic_json_query(self): @@ -330,8 +332,8 @@ def test_basic_json_query(self): self.session.execute(insert_query) results = self.session.execute(json_query) - self.assertEqual(results.column_names, ["[json]"]) - self.assertEqual(results.one()[0], '{"k": 1, "v": 1}') + assert results.column_names == ["[json]"] + assert results.one()[0] == '{"k": 1, "v": 1}' def test_host_targeting_query(self): """ @@ -356,9 +358,9 @@ def test_host_targeting_query(self): future = self.session.execute_async(query, host=host, execution_profile=checkable_ep) future.result() # check we're using the selected host - self.assertEqual(host, future.coordinator_host) + assert host == future.coordinator_host # check that this bypasses the LBP - self.assertFalse(checkable_ep.load_balancing_policy.make_query_plan.called) + assert not checkable_ep.load_balancing_policy.make_query_plan.called class 
PreparedStatementTests(unittest.TestCase): @@ -379,9 +381,9 @@ def test_routing_key(self): INSERT INTO test3rf.test (k, v) VALUES (?, ?) """) - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) bound = prepared.bind((1, None)) - self.assertEqual(bound.routing_key, b'\x00\x00\x00\x01') + assert bound.routing_key == b'\x00\x00\x00\x01' def test_empty_routing_key_indexes(self): """ @@ -394,9 +396,9 @@ def test_empty_routing_key_indexes(self): """) prepared.routing_key_indexes = None - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) bound = prepared.bind((1, None)) - self.assertEqual(bound.routing_key, None) + assert bound.routing_key is None def test_predefined_routing_key(self): """ @@ -408,10 +410,10 @@ def test_predefined_routing_key(self): INSERT INTO test3rf.test (k, v) VALUES (?, ?) """) - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) bound = prepared.bind((1, None)) bound._set_routing_key('fake_key') - self.assertEqual(bound.routing_key, 'fake_key') + assert bound.routing_key == 'fake_key' def test_multiple_routing_key_indexes(self): """ @@ -421,15 +423,15 @@ def test_multiple_routing_key_indexes(self): """ INSERT INTO test3rf.test (k, v) VALUES (?, ?)
""") - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) prepared.routing_key_indexes = [0, 1] bound = prepared.bind((1, 2)) - self.assertEqual(bound.routing_key, b'\x00\x04\x00\x00\x00\x01\x00\x00\x04\x00\x00\x00\x02\x00') + assert bound.routing_key == b'\x00\x04\x00\x00\x00\x01\x00\x00\x04\x00\x00\x00\x02\x00' prepared.routing_key_indexes = [1, 0] bound = prepared.bind((1, 2)) - self.assertEqual(bound.routing_key, b'\x00\x04\x00\x00\x00\x02\x00\x00\x04\x00\x00\x00\x01\x00') + assert bound.routing_key == b'\x00\x04\x00\x00\x00\x02\x00\x00\x04\x00\x00\x00\x01\x00' def test_bound_keyspace(self): """ @@ -440,9 +442,9 @@ def test_bound_keyspace(self): INSERT INTO test3rf.test (k, v) VALUES (?, ?) """) - self.assertIsInstance(prepared, PreparedStatement) + assert isinstance(prepared, PreparedStatement) bound = prepared.bind((1, 2)) - self.assertEqual(bound.keyspace, 'test3rf') + assert bound.keyspace == 'test3rf' class ForcedHostIndexPolicy(RoundRobinPolicy): @@ -490,15 +492,15 @@ def test_prepared_metadata_generation(self): session = cluster.connect() select_statement = session.prepare("SELECT * FROM system.local WHERE key='local'") if proto_version == 1: - self.assertEqual(select_statement.result_metadata, None) + assert select_statement.result_metadata == None else: - self.assertNotEqual(select_statement.result_metadata, None) + assert select_statement.result_metadata != None future = session.execute_async(select_statement) results = future.result() if base_line is None: base_line = results.one()._asdict().keys() else: - self.assertEqual(base_line, results.one()._asdict().keys()) + assert base_line == results.one()._asdict().keys() cluster.shutdown() @@ -522,7 +524,7 @@ def test_prepare_on_all_hosts(self): select_statement = session.prepare("SELECT k FROM test3rf.test WHERE k = ?") for host in clus.metadata.all_hosts(): session.execute(select_statement, (1, ), host=host) - self.assertEqual(2, 
mock_handler.get_message_count('debug', "Re-preparing")) + assert 2 == mock_handler.get_message_count('debug', "Re-preparing") def test_prepare_batch_statement(self): """ @@ -562,12 +564,12 @@ def test_prepare_batch_statement(self): session.execute(batch_statement) # To verify our test assumption that queries are getting re-prepared properly - self.assertEqual(1, mock_handler.get_message_count('debug', "Re-preparing")) + assert 1 == mock_handler.get_message_count('debug', "Re-preparing") select_results = session.execute(SimpleStatement("SELECT * FROM %s WHERE k = 1" % table, consistency_level=ConsistencyLevel.ALL)) first_row = select_results.one()[:2] - self.assertEqual((1, 2), first_row) + assert (1, 2) == first_row def test_prepare_batch_statement_after_alter(self): """ @@ -615,10 +617,10 @@ def test_prepare_batch_statement_after_alter(self): (4, None, 5, None, 6) ] - self.assertEqual(set(expected_results), set(select_results._current_rows)) + assert set(expected_results) == set(select_results._current_rows) # To verify our test assumption that queries are getting re-prepared properly - self.assertEqual(3, mock_handler.get_message_count('debug', "Re-preparing")) + assert 3 == mock_handler.get_message_count('debug', "Re-preparing") class PrintStatementTests(unittest.TestCase): @@ -632,8 +634,7 @@ def test_simple_statement(self): """ ss = SimpleStatement('SELECT * FROM test3rf.test', consistency_level=ConsistencyLevel.ONE) - self.assertEqual(str(ss), - '') + assert str(ss) == '' def test_prepared_statement(self): """ @@ -646,12 +647,10 @@ def test_prepared_statement(self): prepared = session.prepare('INSERT INTO test3rf.test (k, v) VALUES (?, ?)') prepared.consistency_level = ConsistencyLevel.ONE - self.assertEqual(str(prepared), - '') + assert str(prepared) == '' bound = prepared.bind((1, 2)) - self.assertEqual(str(bound), - '') + assert str(bound) == '' cluster.shutdown() @@ -683,8 +682,8 @@ def confirm_results(self): keys.add(result.k) values.add(result.v) - 
self.assertEqual(set(range(10)), keys, msg=results) - self.assertEqual(set(range(10)), values, msg=results) + assert set(range(10)) == keys, results + assert set(range(10)) == values, results def test_string_statements(self): batch = BatchStatement(BatchType.LOGGED) @@ -745,9 +744,12 @@ def test_no_parameters(self): batch.add("INSERT INTO test3rf.test (k, v) VALUES (8, 8)", ()) batch.add("INSERT INTO test3rf.test (k, v) VALUES (9, 9)", ()) - self.assertRaises(ValueError, batch.add, prepared.bind([]), (1)) - self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2)) - self.assertRaises(ValueError, batch.add, prepared.bind([]), (1, 2, 3)) + with pytest.raises(ValueError): + batch.add(prepared.bind([]), (1)) + with pytest.raises(ValueError): + batch.add(prepared.bind([]), (1, 2)) + with pytest.raises(ValueError): + batch.add(prepared.bind([]), (1, 2, 3)) self.session.execute(batch) self.confirm_results() @@ -781,11 +783,13 @@ def test_too_many_statements(self): b = BatchStatement(batch_type=BatchType.UNLOGGED, consistency_level=ConsistencyLevel.ONE) # max + 1 raises b.add_all([ss] * max_statements, [None] * max_statements) - self.assertRaises(ValueError, b.add, ss) + with pytest.raises(ValueError): + b.add(ss) # also would have bombed trying to encode b._statements_and_parameters.append((False, ss.query_string, ())) - self.assertRaises(NoHostAvailable, self.session.execute, b) + with pytest.raises(NoHostAvailable): + self.session.execute(b) class SerialConsistencyTests(unittest.TestCase): @@ -810,22 +814,22 @@ def test_conditional_update(self): serial_consistency_level=ConsistencyLevel.SERIAL) # crazy test, but PYTHON-299 # TODO: expand to check more parameters get passed to statement, and on to messages - self.assertEqual(statement.serial_consistency_level, ConsistencyLevel.SERIAL) + assert statement.serial_consistency_level == ConsistencyLevel.SERIAL future = self.session.execute_async(statement) result = future.result() - 
self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.SERIAL) - self.assertTrue(result) - self.assertFalse(result.one().applied) + assert future.message.serial_consistency_level == ConsistencyLevel.SERIAL + assert result + assert not result.one().applied statement = SimpleStatement( "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0", serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL) - self.assertEqual(statement.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL) + assert statement.serial_consistency_level == ConsistencyLevel.LOCAL_SERIAL future = self.session.execute_async(statement) result = future.result() - self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL) - self.assertTrue(result) - self.assertTrue(result.one().applied) + assert future.message.serial_consistency_level == ConsistencyLevel.LOCAL_SERIAL + assert result + assert result.one().applied def test_conditional_update_with_prepared_statements(self): self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)") @@ -835,9 +839,9 @@ def test_conditional_update_with_prepared_statements(self): statement.serial_consistency_level = ConsistencyLevel.SERIAL future = self.session.execute_async(statement) result = future.result() - self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.SERIAL) - self.assertTrue(result) - self.assertFalse(result.one().applied) + assert future.message.serial_consistency_level == ConsistencyLevel.SERIAL + assert result + assert not result.one().applied statement = self.session.prepare( "UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0") @@ -845,34 +849,36 @@ def test_conditional_update_with_prepared_statements(self): bound.serial_consistency_level = ConsistencyLevel.LOCAL_SERIAL future = self.session.execute_async(bound) result = future.result() - self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL) - self.assertTrue(result) - 
self.assertTrue(result.one().applied) + assert future.message.serial_consistency_level == ConsistencyLevel.LOCAL_SERIAL + assert result + assert result.one().applied def test_conditional_update_with_batch_statements(self): self.session.execute("INSERT INTO test3rf.test (k, v) VALUES (0, 0)") statement = BatchStatement(serial_consistency_level=ConsistencyLevel.SERIAL) statement.add("UPDATE test3rf.test SET v=1 WHERE k=0 IF v=1") - self.assertEqual(statement.serial_consistency_level, ConsistencyLevel.SERIAL) + assert statement.serial_consistency_level == ConsistencyLevel.SERIAL future = self.session.execute_async(statement) result = future.result() - self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.SERIAL) - self.assertTrue(result) - self.assertFalse(result.one().applied) + assert future.message.serial_consistency_level == ConsistencyLevel.SERIAL + assert result + assert not result.one().applied statement = BatchStatement(serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL) statement.add("UPDATE test3rf.test SET v=1 WHERE k=0 IF v=0") - self.assertEqual(statement.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL) + assert statement.serial_consistency_level == ConsistencyLevel.LOCAL_SERIAL future = self.session.execute_async(statement) result = future.result() - self.assertEqual(future.message.serial_consistency_level, ConsistencyLevel.LOCAL_SERIAL) - self.assertTrue(result) - self.assertTrue(result.one().applied) + assert future.message.serial_consistency_level == ConsistencyLevel.LOCAL_SERIAL + assert result + assert result.one().applied def test_bad_consistency_level(self): statement = SimpleStatement("foo") - self.assertRaises(ValueError, setattr, statement, 'serial_consistency_level', ConsistencyLevel.ONE) - self.assertRaises(ValueError, SimpleStatement, 'foo', serial_consistency_level=ConsistencyLevel.ONE) + with pytest.raises(ValueError): + setattr(statement, 'serial_consistency_level', ConsistencyLevel.ONE) + with 
pytest.raises(ValueError): + SimpleStatement('foo', serial_consistency_level=ConsistencyLevel.ONE) class LightweightTransactionTests(unittest.TestCase): @@ -939,16 +945,16 @@ def test_no_connection_refused_on_timeout(self): # In this case result is an exception exception_type = type(result).__name__ if exception_type == "NoHostAvailable": - self.fail("PYTHON-91: Disconnected from Cassandra: %s" % result.message) + pytest.fail("PYTHON-91: Disconnected from Cassandra: %s" % result.message) if exception_type in ["WriteTimeout", "WriteFailure", "ReadTimeout", "ReadFailure", "ErrorMessageSub"]: if type(result).__name__ in ["WriteTimeout", "WriteFailure"]: received_timeout = True continue - self.fail("Unexpected exception %s: %s" % (exception_type, result.message)) + pytest.fail("Unexpected exception %s: %s" % (exception_type, result.message)) # Make sure test passed - self.assertTrue(received_timeout) + assert received_timeout @xfail_scylla('Fails on Scylla with error `SERIAL/LOCAL_SERIAL consistency may only be requested for one partition at a time`') def test_was_applied_batch_stmt(self): @@ -975,7 +981,7 @@ def test_was_applied_batch_stmt(self): "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 1, 10);", "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 2, 10);"], [None] * 3) result = self.session.execute(batch_statement) - #self.assertTrue(result.was_applied) + #assert result.was_applied # Should fail since (0, 0, 10) have already been written # The non conditional insert shouldn't be written as well @@ -985,11 +991,11 @@ def test_was_applied_batch_stmt(self): "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 4, 10);", "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 5, 10) IF NOT EXISTS;"], [None] * 4) result = self.session.execute(batch_statement) - self.assertFalse(result.was_applied) + assert not result.was_applied all_rows = self.session.execute("SELECT * from test3rf.lwt_clustering", execution_profile='serial') # Verify the non 
conditional insert hasn't been inserted - self.assertEqual(len(all_rows.current_rows), 3) + assert len(all_rows.current_rows) == 3 # Should fail since (0, 0, 10) have already been written batch_statement = BatchStatement(batch_type) @@ -997,12 +1003,12 @@ def test_was_applied_batch_stmt(self): "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 3, 10) IF NOT EXISTS;", "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 5, 10) IF NOT EXISTS;"], [None] * 3) result = self.session.execute(batch_statement) - self.assertFalse(result.was_applied) + assert not result.was_applied # Should fail since (0, 0, 10) have already been written batch_statement.add("INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 0, 10) IF NOT EXISTS;") result = self.session.execute(batch_statement) - self.assertFalse(result.was_applied) + assert not result.was_applied # Should succeed batch_statement = BatchStatement(batch_type) @@ -1011,11 +1017,11 @@ def test_was_applied_batch_stmt(self): "INSERT INTO test3rf.lwt_clustering (k, c, v) VALUES (0, 5, 10) IF NOT EXISTS;"], [None] * 3) result = self.session.execute(batch_statement) - self.assertTrue(result.was_applied) + assert result.was_applied all_rows = self.session.execute("SELECT * from test3rf.lwt_clustering", execution_profile='serial') for i, row in enumerate(all_rows): - self.assertEqual((0, i, 10), (row[0], row[1], row[2])) + assert (0, i, 10) == (row[0], row[1], row[2]) self.session.execute("TRUNCATE TABLE test3rf.lwt_clustering") @@ -1033,7 +1039,7 @@ def test_empty_batch_statement(self): """ batch_statement = BatchStatement() results = self.session.execute(batch_statement) - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): results.was_applied @pytest.mark.xfail(reason='Skipping until PYTHON-943 is resolved') @@ -1052,7 +1058,7 @@ def test_was_applied_batch_string(self): APPLY batch; """ result = self.session.execute(batch_str) - self.assertFalse(result.was_applied) + assert not result.was_applied 
batch_str = """ BEGIN unlogged batch @@ -1062,7 +1068,7 @@ def test_was_applied_batch_string(self): APPLY batch; """ result = self.session.execute(batch_str) - self.assertTrue(result.was_applied) + assert result.was_applied class BatchStatementDefaultRoutingKeyTests(unittest.TestCase): @@ -1091,8 +1097,8 @@ def test_rk_from_bound(self): bound = self.prepared.bind((1, None)) batch = BatchStatement() batch.add(bound) - self.assertIsNotNone(batch.routing_key) - self.assertEqual(batch.routing_key, bound.routing_key) + assert batch.routing_key is not None + assert batch.routing_key == bound.routing_key def test_rk_from_simple(self): """ @@ -1100,8 +1106,8 @@ def test_rk_from_simple(self): """ batch = BatchStatement() batch.add(self.simple_statement) - self.assertIsNotNone(batch.routing_key) - self.assertEqual(batch.routing_key, self.simple_statement.routing_key) + assert batch.routing_key is not None + assert batch.routing_key == self.simple_statement.routing_key def test_inherit_first_rk_bound(self): """ @@ -1116,8 +1122,8 @@ def test_inherit_first_rk_bound(self): for i in range(3): batch.add(self.prepared, (i, i)) - self.assertIsNotNone(batch.routing_key) - self.assertEqual(batch.routing_key, bound.routing_key) + assert batch.routing_key is not None + assert batch.routing_key == bound.routing_key def test_inherit_first_rk_simple_statement(self): """ @@ -1132,8 +1138,8 @@ def test_inherit_first_rk_simple_statement(self): for i in range(10): batch.add(self.prepared, (i, i)) - self.assertIsNotNone(batch.routing_key) - self.assertEqual(batch.routing_key, self.simple_statement.routing_key) + assert batch.routing_key is not None + assert batch.routing_key == self.simple_statement.routing_key def test_inherit_first_rk_prepared_param(self): """ @@ -1146,8 +1152,8 @@ def test_inherit_first_rk_prepared_param(self): batch.add(bound) batch.add(self.simple_statement) - self.assertIsNotNone(batch.routing_key) - self.assertEqual(batch.routing_key, self.prepared.bind((1, 
0)).routing_key) + assert batch.routing_key is not None + assert batch.routing_key == self.prepared.bind((1, 0)).routing_key @greaterthanorequalcass30 @@ -1243,73 +1249,73 @@ def test_mv_filtering(self): query_statement = SimpleStatement("SELECT * FROM {0}.alltimehigh WHERE game='Coup'".format(self.keyspace_name), consistency_level=ConsistencyLevel.QUORUM) results = self.session.execute(query_statement) - self.assertEqual(results.one().game, 'Coup') - self.assertEqual(results.one().year, 2015) - self.assertEqual(results.one().month, 5) - self.assertEqual(results.one().day, 1) - self.assertEqual(results.one().score, 4000) - self.assertEqual(results.one().user, "pcmanus") + assert results.one().game == 'Coup' + assert results.one().year == 2015 + assert results.one().month == 5 + assert results.one().day == 1 + assert results.one().score == 4000 + assert results.one().user == "pcmanus" # Test prepared statement and daily high filtering prepared_query = self.session.prepare("SELECT * FROM {0}.dailyhigh WHERE game=? AND year=? AND month=? 
and day=?".format(self.keyspace_name)) bound_query = prepared_query.bind(("Coup", 2015, 6, 2)) results = self.session.execute(bound_query) - self.assertEqual(results.one().game, 'Coup') - self.assertEqual(results.one().year, 2015) - self.assertEqual(results.one().month, 6) - self.assertEqual(results.one().day, 2) - self.assertEqual(results.one().score, 2000) - self.assertEqual(results.one().user, "pcmanus") - - self.assertEqual(results[1].game, 'Coup') - self.assertEqual(results[1].year, 2015) - self.assertEqual(results[1].month, 6) - self.assertEqual(results[1].day, 2) - self.assertEqual(results[1].score, 1000) - self.assertEqual(results[1].user, "tjake") + assert results.one().game == 'Coup' + assert results.one().year == 2015 + assert results.one().month == 6 + assert results.one().day == 2 + assert results.one().score == 2000 + assert results.one().user == "pcmanus" + + assert results[1].game == 'Coup' + assert results[1].year == 2015 + assert results[1].month == 6 + assert results[1].day == 2 + assert results[1].score == 1000 + assert results[1].user == "tjake" # Test montly high range queries prepared_query = self.session.prepare("SELECT * FROM {0}.monthlyhigh WHERE game=? AND year=? AND month=? and score >= ? 
and score <= ?".format(self.keyspace_name)) bound_query = prepared_query.bind(("Coup", 2015, 6, 2500, 3500)) results = self.session.execute(bound_query) - self.assertEqual(results.one().game, 'Coup') - self.assertEqual(results.one().year, 2015) - self.assertEqual(results.one().month, 6) - self.assertEqual(results.one().day, 20) - self.assertEqual(results.one().score, 3500) - self.assertEqual(results.one().user, "jbellis") - - self.assertEqual(results[1].game, 'Coup') - self.assertEqual(results[1].year, 2015) - self.assertEqual(results[1].month, 6) - self.assertEqual(results[1].day, 9) - self.assertEqual(results[1].score, 2700) - self.assertEqual(results[1].user, "jmckenzie") - - self.assertEqual(results[2].game, 'Coup') - self.assertEqual(results[2].year, 2015) - self.assertEqual(results[2].month, 6) - self.assertEqual(results[2].day, 1) - self.assertEqual(results[2].score, 2500) - self.assertEqual(results[2].user, "iamaleksey") + assert results.one().game == 'Coup' + assert results.one().year == 2015 + assert results.one().month == 6 + assert results.one().day == 20 + assert results.one().score == 3500 + assert results.one().user == "jbellis" + + assert results[1].game == 'Coup' + assert results[1].year == 2015 + assert results[1].month == 6 + assert results[1].day == 9 + assert results[1].score == 2700 + assert results[1].user == "jmckenzie" + + assert results[2].game == 'Coup' + assert results[2].year == 2015 + assert results[2].month == 6 + assert results[2].day == 1 + assert results[2].score == 2500 + assert results[2].user == "iamaleksey" # Test filtered user high scores query_statement = SimpleStatement("SELECT * FROM {0}.filtereduserhigh WHERE game='Chess'".format(self.keyspace_name), consistency_level=ConsistencyLevel.QUORUM) results = self.session.execute(query_statement) - self.assertEqual(results.one().game, 'Chess') - self.assertEqual(results.one().year, 2015) - self.assertEqual(results.one().month, 6) - self.assertEqual(results.one().day, 21) - 
self.assertEqual(results.one().score, 3500) - self.assertEqual(results.one().user, "jbellis") + assert results.one().game == 'Chess' + assert results.one().year == 2015 + assert results.one().month == 6 + assert results.one().day == 21 + assert results.one().score == 3500 + assert results.one().user == "jbellis" - self.assertEqual(results[1].game, 'Chess') - self.assertEqual(results[1].year, 2015) - self.assertEqual(results[1].month, 1) - self.assertEqual(results[1].day, 25) - self.assertEqual(results[1].score, 3200) - self.assertEqual(results[1].user, "pcmanus") + assert results[1].game == 'Chess' + assert results[1].year == 2015 + assert results[1].month == 1 + assert results[1].day == 25 + assert results[1].score == 3200 + assert results[1].user == "pcmanus" class UnicodeQueryTest(BasicSharedKeyspaceUnitTestCase): @@ -1471,18 +1477,18 @@ def test_lower_protocol(self): # set on queries with protocol version 5 or higher. Consider setting Cluster.protocol_version to 5.',), # : ConnectionException('Host has been marked down or removed',), # : ConnectionException('Host has been marked down or removed',)}) - with self.assertRaises(NoHostAvailable): + with pytest.raises(NoHostAvailable): session.execute(simple_stmt) def _check_set_keyspace_in_statement(self, session): simple_stmt = SimpleStatement("SELECT * from {}".format(self.table_name), keyspace=self.ks_name) results = session.execute(simple_stmt) - self.assertEqual(results.one(), (1, 1)) + assert results.one() == (1, 1) simple_stmt = SimpleStatement("SELECT * from {}".format(self.table_name)) simple_stmt.keyspace = self.ks_name results = session.execute(simple_stmt) - self.assertEqual(results.one(), (1, 1)) + assert results.one() == (1, 1) @greaterthanorequalcass40 @@ -1507,8 +1513,8 @@ def confirm_results(self): keys.add(result.k) values.add(result.v) - self.assertEqual(set(range(10)), keys, msg=results) - self.assertEqual(set(range(10)), values, msg=results) + assert set(range(10)) == keys, results + assert 
set(range(10)) == values, results @greaterthanorequalcass40 @@ -1536,14 +1542,14 @@ def test_prepared_with_keyspace_explicit(self): prepared_statement = self.session.prepare(query, keyspace=self.ks_name) results = self.session.execute(prepared_statement, (1, )) - self.assertEqual(results.one(), (1, 1)) + assert results.one() == (1, 1) prepared_statement_alternative = self.session.prepare(query, keyspace=self.alternative_ks) - self.assertNotEqual(prepared_statement.query_id, prepared_statement_alternative.query_id) + assert prepared_statement.query_id != prepared_statement_alternative.query_id results = self.session.execute(prepared_statement_alternative, (2,)) - self.assertEqual(results.one(), (2, 2)) + assert results.one() == (2, 2) def test_reprepare_after_host_is_down(self): """ @@ -1570,13 +1576,13 @@ def test_reprepare_after_host_is_down(self): # We wait for cluster._prepare_all_queries to be called time.sleep(5) - self.assertEqual(1, mock_handler.get_message_count('debug', 'Preparing all known prepared statements')) + assert 1 == mock_handler.get_message_count('debug', 'Preparing all known prepared statements') results = self.session.execute(prepared_statement, (1,), execution_profile="only_first") - self.assertEqual(results.one(), (1, )) + assert results.one() == (1, ) results = self.session.execute(prepared_statement_alternative, (2,), execution_profile="only_first") - self.assertEqual(results.one(), (2, )) + assert results.one() == (2, ) def test_prepared_not_found(self): """ @@ -1600,7 +1606,7 @@ def test_prepared_not_found(self): for _ in range(10): results = session.execute(prepared_statement, (1, )) - self.assertEqual(results.one(), (1,)) + assert results.one() == (1,) def test_prepared_in_query_keyspace(self): """ @@ -1619,12 +1625,12 @@ def test_prepared_in_query_keyspace(self): query = "SELECT k from {}.{} WHERE k = ?".format(self.ks_name, self.table_name) prepared_statement = session.prepare(query) results = session.execute(prepared_statement, 
(1,)) - self.assertEqual(results.one(), (1,)) + assert results.one() == (1,) query = "SELECT k from {}.{} WHERE k = ?".format(self.alternative_ks, self.table_name) prepared_statement = session.prepare(query) results = session.execute(prepared_statement, (2,)) - self.assertEqual(results.one(), (2,)) + assert results.one() == (2,) def test_prepared_in_query_keyspace_and_explicit(self): """ @@ -1641,9 +1647,9 @@ def test_prepared_in_query_keyspace_and_explicit(self): query = "SELECT k from {}.{} WHERE k = ?".format(self.ks_name, self.table_name) prepared_statement = self.session.prepare(query, keyspace="system") results = self.session.execute(prepared_statement, (1,)) - self.assertEqual(results.one(), (1,)) + assert results.one() == (1,) query = "SELECT k from {}.{} WHERE k = ?".format(self.ks_name, self.table_name) prepared_statement = self.session.prepare(query, keyspace=self.alternative_ks) results = self.session.execute(prepared_statement, (1,)) - self.assertEqual(results.one(), (1,)) + assert results.one() == (1,) diff --git a/tests/integration/standard/test_query_paging.py b/tests/integration/standard/test_query_paging.py index 26c1ca0da6..28567d991b 100644 --- a/tests/integration/standard/test_query_paging.py +++ b/tests/integration/standard/test_query_paging.py @@ -20,6 +20,7 @@ from itertools import cycle, count from threading import Event +import pytest from cassandra import ConsistencyLevel from cassandra.cluster import EXEC_PROFILE_DEFAULT, ExecutionProfile @@ -27,6 +28,8 @@ from cassandra.policies import HostDistance from cassandra.query import SimpleStatement +from tests.util import assertSequenceEqual + def setup_module(): use_singledc() @@ -60,12 +63,12 @@ def test_paging(self): for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000): self.session.default_fetch_size = fetch_size - self.assertEqual(100, len(list(self.session.execute("SELECT * FROM test3rf.test")))) + assert 100 == len(list(self.session.execute("SELECT * FROM test3rf.test"))) statement = 
SimpleStatement("SELECT * FROM test3rf.test") - self.assertEqual(100, len(list(self.session.execute(statement)))) + assert 100 == len(list(self.session.execute(statement))) - self.assertEqual(100, len(list(self.session.execute(prepared)))) + assert 100 == len(list(self.session.execute(prepared))) def test_paging_state(self): """ @@ -86,14 +89,14 @@ def test_paging_state(self): result_set = self.session.execute("SELECT * FROM test3rf.test") while(result_set.has_more_pages): for row in result_set.current_rows: - self.assertNotIn(row, list_all_results) + assert row not in list_all_results list_all_results.extend(result_set.current_rows) page_state = result_set.paging_state result_set = self.session.execute("SELECT * FROM test3rf.test", paging_state=page_state) if(len(result_set.current_rows) > 0): list_all_results.append(result_set.current_rows) - self.assertEqual(len(list_all_results), 100) + assert len(list_all_results) == 100 def test_paging_verify_writes(self): statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), @@ -111,8 +114,8 @@ def test_paging_verify_writes(self): result_array.add(result.k) result_set.add(result.v) - self.assertEqual(set(range(100)), result_array) - self.assertEqual(set([0]), result_set) + assert set(range(100)) == result_array + assert set([0]) == result_set statement = SimpleStatement("SELECT * FROM test3rf.test") results = self.session.execute(statement) @@ -122,8 +125,8 @@ def test_paging_verify_writes(self): result_array.add(result.k) result_set.add(result.v) - self.assertEqual(set(range(100)), result_array) - self.assertEqual(set([0]), result_set) + assert set(range(100)) == result_array + assert set([0]) == result_set results = self.session.execute(prepared) result_array = set() @@ -132,8 +135,8 @@ def test_paging_verify_writes(self): result_array.add(result.k) result_set.add(result.v) - self.assertEqual(set(range(100)), result_array) - self.assertEqual(set([0]), result_set) + assert set(range(100)) == 
result_array + assert set([0]) == result_set def test_paging_verify_with_composite_keys(self): ddl = ''' @@ -161,8 +164,8 @@ def test_paging_verify_with_composite_keys(self): result_array.append(result.k2) value_array.append(result.v) - self.assertSequenceEqual(range(100), result_array) - self.assertSequenceEqual(range(1, 101), value_array) + assertSequenceEqual(range(100), result_array) + assertSequenceEqual(range(1, 101), value_array) statement = SimpleStatement("SELECT * FROM test3rf.test_paging_verify_2") results = self.session.execute(statement) @@ -172,8 +175,8 @@ def test_paging_verify_with_composite_keys(self): result_array.append(result.k2) value_array.append(result.v) - self.assertSequenceEqual(range(100), result_array) - self.assertSequenceEqual(range(1, 101), value_array) + assertSequenceEqual(range(100), result_array) + assertSequenceEqual(range(1, 101), value_array) results = self.session.execute(prepared) result_array = [] @@ -182,8 +185,8 @@ def test_paging_verify_with_composite_keys(self): result_array.append(result.k2) value_array.append(result.v) - self.assertSequenceEqual(range(100), result_array) - self.assertSequenceEqual(range(1, 101), value_array) + assertSequenceEqual(range(100), result_array) + assertSequenceEqual(range(1, 101), value_array) def test_async_paging(self): statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), @@ -194,12 +197,12 @@ def test_async_paging(self): for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000): self.session.default_fetch_size = fetch_size - self.assertEqual(100, len(list(self.session.execute_async("SELECT * FROM test3rf.test").result()))) + assert 100 == len(list(self.session.execute_async("SELECT * FROM test3rf.test").result())) statement = SimpleStatement("SELECT * FROM test3rf.test") - self.assertEqual(100, len(list(self.session.execute_async(statement).result()))) + assert 100 == len(list(self.session.execute_async(statement).result())) - self.assertEqual(100, 
len(list(self.session.execute_async(prepared).result()))) + assert 100 == len(list(self.session.execute_async(prepared).result())) def test_async_paging_verify_writes(self): ddl = ''' @@ -227,8 +230,8 @@ def test_async_paging_verify_writes(self): result_array.append(result.k2) value_array.append(result.v) - self.assertSequenceEqual(range(100), result_array) - self.assertSequenceEqual(range(1, 101), value_array) + assertSequenceEqual(range(100), result_array) + assertSequenceEqual(range(1, 101), value_array) statement = SimpleStatement("SELECT * FROM test3rf.test_async_paging_verify") results = self.session.execute_async(statement).result() @@ -238,8 +241,8 @@ def test_async_paging_verify_writes(self): result_array.append(result.k2) value_array.append(result.v) - self.assertSequenceEqual(range(100), result_array) - self.assertSequenceEqual(range(1, 101), value_array) + assertSequenceEqual(range(100), result_array) + assertSequenceEqual(range(1, 101), value_array) results = self.session.execute_async(prepared).result() result_array = [] @@ -248,8 +251,8 @@ def test_async_paging_verify_writes(self): result_array.append(result.k2) value_array.append(result.v) - self.assertSequenceEqual(range(100), result_array) - self.assertSequenceEqual(range(1, 101), value_array) + assertSequenceEqual(range(100), result_array) + assertSequenceEqual(range(1, 101), value_array) def test_paging_callbacks(self): """ @@ -287,13 +290,13 @@ def handle_page(rows, future, counter, number_of_calls): def handle_error(err): event.set() - self.fail(err) + pytest.fail(err) future.add_callbacks(callback=handle_page, callback_args=(future, counter, number_of_calls), errback=handle_error) event.wait() - self.assertEqual(next(number_of_calls), 100 // fetch_size + 1) - self.assertEqual(next(counter), 100) + assert next(number_of_calls) == 100 // fetch_size + 1 + assert next(counter) == 100 # simple statement future = self.session.execute_async(SimpleStatement("SELECT * FROM test3rf.test"), timeout=20) 
@@ -304,8 +307,8 @@ def handle_error(err): future.add_callbacks(callback=handle_page, callback_args=(future, counter, number_of_calls), errback=handle_error) event.wait() - self.assertEqual(next(number_of_calls), 100 // fetch_size + 1) - self.assertEqual(next(counter), 100) + assert next(number_of_calls) == 100 // fetch_size + 1 + assert next(counter) == 100 # prepared statement future = self.session.execute_async(prepared, timeout=20) @@ -316,8 +319,8 @@ def handle_error(err): future.add_callbacks(callback=handle_page, callback_args=(future, counter, number_of_calls), errback=handle_error) event.wait() - self.assertEqual(next(number_of_calls), 100 // fetch_size + 1) - self.assertEqual(next(counter), 100) + assert next(number_of_calls) == 100 // fetch_size + 1 + assert next(counter) == 100 def test_concurrent_with_paging(self): statements_and_params = zip(cycle(["INSERT INTO test3rf.test (k, v) VALUES (%s, 0)"]), @@ -329,10 +332,10 @@ def test_concurrent_with_paging(self): for fetch_size in (2, 3, 7, 10, 99, 100, 101, 10000): self.session.default_fetch_size = fetch_size results = execute_concurrent_with_args(self.session, prepared, [None] * 10) - self.assertEqual(10, len(results)) + assert 10 == len(results) for (success, result) in results: - self.assertTrue(success) - self.assertEqual(100, len(list(result))) + assert success + assert 100 == len(list(result)) def test_fetch_size(self): """ @@ -346,66 +349,66 @@ def test_fetch_size(self): self.session.default_fetch_size = 10 result = self.session.execute(prepared, []) - self.assertTrue(result.has_more_pages) + assert result.has_more_pages self.session.default_fetch_size = 2000 result = self.session.execute(prepared, []) - self.assertFalse(result.has_more_pages) + assert not result.has_more_pages self.session.default_fetch_size = None result = self.session.execute(prepared, []) - self.assertFalse(result.has_more_pages) + assert not result.has_more_pages self.session.default_fetch_size = 10 prepared.fetch_size = 2000 
result = self.session.execute(prepared, []) - self.assertFalse(result.has_more_pages) + assert not result.has_more_pages prepared.fetch_size = None result = self.session.execute(prepared, []) - self.assertFalse(result.has_more_pages) + assert not result.has_more_pages prepared.fetch_size = 10 result = self.session.execute(prepared, []) - self.assertTrue(result.has_more_pages) + assert result.has_more_pages prepared.fetch_size = 2000 bound = prepared.bind([]) result = self.session.execute(bound, []) - self.assertFalse(result.has_more_pages) + assert not result.has_more_pages prepared.fetch_size = None bound = prepared.bind([]) result = self.session.execute(bound, []) - self.assertFalse(result.has_more_pages) + assert not result.has_more_pages prepared.fetch_size = 10 bound = prepared.bind([]) result = self.session.execute(bound, []) - self.assertTrue(result.has_more_pages) + assert result.has_more_pages bound.fetch_size = 2000 result = self.session.execute(bound, []) - self.assertFalse(result.has_more_pages) + assert not result.has_more_pages bound.fetch_size = None result = self.session.execute(bound, []) - self.assertFalse(result.has_more_pages) + assert not result.has_more_pages bound.fetch_size = 10 result = self.session.execute(bound, []) - self.assertTrue(result.has_more_pages) + assert result.has_more_pages s = SimpleStatement("SELECT * FROM test3rf.test", fetch_size=None) result = self.session.execute(s, []) - self.assertFalse(result.has_more_pages) + assert not result.has_more_pages s = SimpleStatement("SELECT * FROM test3rf.test") result = self.session.execute(s, []) - self.assertTrue(result.has_more_pages) + assert result.has_more_pages s = SimpleStatement("SELECT * FROM test3rf.test") s.fetch_size = None result = self.session.execute(s, []) - self.assertFalse(result.has_more_pages) + assert not result.has_more_pages diff --git a/tests/integration/standard/test_rack_aware_policy.py b/tests/integration/standard/test_rack_aware_policy.py index 
5d7a69642f..d2a358373d 100644 --- a/tests/integration/standard/test_rack_aware_policy.py +++ b/tests/integration/standard/test_rack_aware_policy.py @@ -66,24 +66,24 @@ def test_rack_aware(self): for i in range (10): bound = prepared.bind([i]) results = self.session.execute(bound) - self.assertEqual(results, [(i, i%5, i%2)]) + assert results == [(i, i%5, i%2)] coordinator = str(results.response_future.coordinator_host.endpoint) - self.assertTrue(coordinator in set(["127.0.0.1:9042", "127.0.0.2:9042"])) + assert coordinator in set(["127.0.0.1:9042", "127.0.0.2:9042"]) self.node2.stop(wait_other_notice=True, gently=True) for i in range (10): bound = prepared.bind([i]) results = self.session.execute(bound) - self.assertEqual(results, [(i, i%5, i%2)]) + assert results == [(i, i%5, i%2)] coordinator =str(results.response_future.coordinator_host.endpoint) - self.assertEqual(coordinator, "127.0.0.1:9042") + assert coordinator == "127.0.0.1:9042" self.node1.stop(wait_other_notice=True, gently=True) for i in range (10): bound = prepared.bind([i]) results = self.session.execute(bound) - self.assertEqual(results, [(i, i%5, i%2)]) + assert results == [(i, i%5, i%2)] coordinator = str(results.response_future.coordinator_host.endpoint) - self.assertTrue(coordinator in set(["127.0.0.3:9042", "127.0.0.4:9042"])) + assert coordinator in set(["127.0.0.3:9042", "127.0.0.4:9042"]) diff --git a/tests/integration/standard/test_rate_limit_exceeded.py b/tests/integration/standard/test_rate_limit_exceeded.py index 280d6426e1..211f0c9930 100644 --- a/tests/integration/standard/test_rate_limit_exceeded.py +++ b/tests/integration/standard/test_rate_limit_exceeded.py @@ -5,6 +5,7 @@ from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy from tests.integration import PROTOCOL_VERSION, use_cluster +import pytest LOGGER = logging.getLogger(__name__) @@ -31,17 +32,17 @@ def test_rate_limit_exceeded(self): ) self.session.execute( """ - CREATE KEYSPACE IF NOT 
EXISTS ratetests + CREATE KEYSPACE IF NOT EXISTS ratetests WITH REPLICATION = {'class' : 'SimpleStrategy', 'replication_factor' : 1} """) self.session.execute("USE ratetests") self.session.execute( """ - CREATE TABLE tbl (pk int PRIMARY KEY, v int) + CREATE TABLE tbl (pk int PRIMARY KEY, v int) WITH per_partition_rate_limit = {'max_writes_per_second': 1} """) - + prepared = self.session.prepare( """ INSERT INTO tbl (pk, v) VALUES (?, ?) @@ -53,7 +54,7 @@ def execute_write(): for _ in range(1000): self.session.execute(prepared.bind((123, 456))) - with self.assertRaises(RateLimitReached) as context: + with pytest.raises(RateLimitReached) as context: execute_write() - self.assertEqual(context.exception.op_type, OperationType.Write) + assert context.value.op_type == OperationType.Write diff --git a/tests/integration/standard/test_routing.py b/tests/integration/standard/test_routing.py index 7d6651cf8b..ae45e7ade4 100644 --- a/tests/integration/standard/test_routing.py +++ b/tests/integration/standard/test_routing.py @@ -50,7 +50,7 @@ def insert_select_token(self, insert, select, key_values): cass_token = s.execute(select, key_values).one()[0] token = s.cluster.metadata.token_map.token_class(cass_token) - self.assertEqual(my_token, token) + assert my_token == token def create_prepare(self, key_types): s = self.session diff --git a/tests/integration/standard/test_row_factories.py b/tests/integration/standard/test_row_factories.py index 413a6bf50b..187f35704a 100644 --- a/tests/integration/standard/test_row_factories.py +++ b/tests/integration/standard/test_row_factories.py @@ -16,6 +16,7 @@ BasicSharedKeyspaceUnitTestCaseWFunctionTable, BasicSharedKeyspaceUnitTestCase, execute_until_pass, TestCluster import unittest +import pytest from cassandra.cluster import ResultSet, ExecutionProfile, EXEC_PROFILE_DEFAULT from cassandra.query import tuple_factory, named_tuple_factory, dict_factory, ordered_dict_factory @@ -66,9 +67,9 @@ def test_sanitizing(self): query = "SELECT v1 
AS duplicate, v2 AS duplicate, v3 AS duplicate from {0}.{1}".format(self.ks_name, self.function_table_name) rs = self.session.execute(query) row = rs.one() - self.assertTrue(hasattr(row, 'duplicate')) - self.assertTrue(hasattr(row, 'duplicate_')) - self.assertTrue(hasattr(row, 'duplicate__')) + assert hasattr(row, 'duplicate') + assert hasattr(row, 'duplicate_') + assert hasattr(row, 'duplicate__') class RowFactoryTests(BasicSharedKeyspaceUnitTestCaseWFunctionTable): @@ -92,44 +93,44 @@ def _results_from_row_factory(self, row_factory): def test_tuple_factory(self): result = self._results_from_row_factory(tuple_factory) - self.assertIsInstance(result, ResultSet) - self.assertIsInstance(result.one(), tuple) + assert isinstance(result, ResultSet) + assert isinstance(result.one(), tuple) result = result.all() for row in result: - self.assertEqual(row[0], row[1]) + assert row[0] == row[1] - self.assertEqual(result[0][0], result[0][1]) - self.assertEqual(result[0][0], 1) - self.assertEqual(result[1][0], result[1][1]) - self.assertEqual(result[1][0], 2) + assert result[0][0] == result[0][1] + assert result[0][0] == 1 + assert result[1][0] == result[1][1] + assert result[1][0] == 2 def test_named_tuple_factory(self): result = self._results_from_row_factory(named_tuple_factory) - self.assertIsInstance(result, ResultSet) + assert isinstance(result, ResultSet) result = result.all() for row in result: - self.assertEqual(row.k, row.v) + assert row.k == row.v - self.assertEqual(result[0].k, result[0].v) - self.assertEqual(result[0].k, 1) - self.assertEqual(result[1].k, result[1].v) - self.assertEqual(result[1].k, 2) + assert result[0].k == result[0].v + assert result[0].k == 1 + assert result[1].k == result[1].v + assert result[1].k == 2 def _test_dict_factory(self, row_factory, row_type): result = self._results_from_row_factory(row_factory) - self.assertIsInstance(result, ResultSet) - self.assertIsInstance(result.one(), row_type) + assert isinstance(result, ResultSet) + assert 
isinstance(result.one(), row_type) result = result.all() for row in result: - self.assertEqual(row['k'], row['v']) + assert row['k'] == row['v'] - self.assertEqual(result[0]['k'], result[0]['v']) - self.assertEqual(result[0]['k'], 1) - self.assertEqual(result[1]['k'], result[1]['v']) - self.assertEqual(result[1]['k'], 2) + assert result[0]['k'] == result[0]['v'] + assert result[0]['k'] == 1 + assert result[1]['k'] == result[1]['v'] + assert result[1]['k'] == 2 def test_dict_factory(self): self._test_dict_factory(dict_factory, dict) @@ -164,9 +165,9 @@ def _gen_row_factory(rows): ( 1 , 1 ) '''.format(self.keyspace_name, self.function_table_name)) result = session.execute(self.select) - self.assertIsInstance(result, ResultSet) + assert isinstance(result, ResultSet) first_row = result.one() - self.assertEqual(first_row[0], first_row[1]) + assert first_row[0] == first_row[1] class NamedTupleFactoryAndNumericColNamesTests(unittest.TestCase): @@ -194,7 +195,7 @@ def test_no_exception_on_select(self): try: self.session.execute('SELECT * FROM test1rf.table_num_col') except ValueError as e: - self.fail("Unexpected ValueError exception: %s" % e.message) + pytest.fail("Unexpected ValueError exception: %s" % e.message) def test_can_select_using_alias(self): """ @@ -206,7 +207,7 @@ def test_can_select_using_alias(self): try: self.session.execute('SELECT key, "626972746864617465" AS my_col from test1rf.table_num_col') except ValueError as e: - self.fail("Unexpected ValueError exception: %s" % e.message) + pytest.fail("Unexpected ValueError exception: %s" % e.message) def test_can_select_with_dict_factory(self): """ @@ -218,4 +219,4 @@ def test_can_select_with_dict_factory(self): try: cluster.connect().execute('SELECT * FROM test1rf.table_num_col') except ValueError as e: - self.fail("Unexpected ValueError exception: %s" % e.message) + pytest.fail("Unexpected ValueError exception: %s" % e.message) diff --git a/tests/integration/standard/test_shard_aware.py 
b/tests/integration/standard/test_shard_aware.py index cf8f17e209..215c69c3b4 100644 --- a/tests/integration/standard/test_shard_aware.py +++ b/tests/integration/standard/test_shard_aware.py @@ -62,8 +62,8 @@ def verify_same_shard_in_tracing(self, results, shard_name): for event in events: LOGGER.info("%s %s %s", event.source, event.thread_name, event.description) for event in events: - self.assertIn(shard_name, event.thread_name) - self.assertIn('querying locally', "\n".join([event.description for event in events])) + assert shard_name in event.thread_name + assert 'querying locally' in "\n".join([event.description for event in events]) trace_id = results.response_future.get_query_trace_ids()[0] traces = self.session.execute("SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id,)) @@ -71,8 +71,8 @@ def verify_same_shard_in_tracing(self, results, shard_name): for event in events: LOGGER.info("%s %s", event.thread, event.activity) for event in events: - self.assertIn(shard_name, event.thread) - self.assertIn('querying locally', "\n".join([event.activity for event in events])) + assert shard_name in event.thread + assert 'querying locally' in "\n".join([event.activity for event in events]) def create_ks_and_cf(self): self.session.execute( @@ -120,13 +120,13 @@ def query_data(self, session, verify_in_tracing=True): bound = prepared.bind(('a', 'b')) results = session.execute(bound, trace=True) - self.assertEqual(results, [('a', 'b', 'c')]) + assert results == [('a', 'b', 'c')] if verify_in_tracing: self.verify_same_shard_in_tracing(results, "shard 0") bound = prepared.bind(('100002', 'f')) results = session.execute(bound, trace=True) - self.assertEqual(results, [('100002', 'f', 'g')]) + assert results == [('100002', 'f', 'g')] if verify_in_tracing: self.verify_same_shard_in_tracing(results, "shard 1") diff --git a/tests/integration/standard/test_single_interface.py b/tests/integration/standard/test_single_interface.py index 681e992477..3fd90b9708 100644 
--- a/tests/integration/standard/test_single_interface.py +++ b/tests/integration/standard/test_single_interface.py @@ -13,6 +13,7 @@ # limitations under the License. import unittest +import pytest from cassandra import ConsistencyLevel from cassandra.query import SimpleStatement @@ -52,17 +53,15 @@ def test_single_interface(self): hosts = self.cluster.metadata._hosts broadcast_rpc_ports = [] broadcast_ports = [] - self.assertEqual(len(hosts), 3) + assert len(hosts) == 3 for endpoint, host in hosts.items(): - self.assertEqual(endpoint.address, host.broadcast_rpc_address) - self.assertEqual(endpoint.port, host.broadcast_rpc_port) + assert endpoint.address == host.broadcast_rpc_address + assert endpoint.port == host.broadcast_rpc_port - if host.broadcast_rpc_port in broadcast_rpc_ports: - self.fail("Duplicate broadcast_rpc_port") + assert host.broadcast_rpc_port not in broadcast_rpc_ports, "Duplicate broadcast_rpc_port" broadcast_rpc_ports.append(host.broadcast_rpc_port) - if host.broadcast_port in broadcast_ports: - self.fail("Duplicate broadcast_port") + assert host.broadcast_port not in broadcast_ports, "Duplicate broadcast_port" broadcast_ports.append(host.broadcast_port) for _ in range(1, 100): @@ -70,4 +69,4 @@ def test_single_interface(self): consistency_level=ConsistencyLevel.ALL)) for pool in self.session.get_pools(): - self.assertEqual(1, pool.get_state()['open_count']) + assert 1 == pool.get_state()['open_count'] diff --git a/tests/integration/standard/test_tablets.py b/tests/integration/standard/test_tablets.py index 79dd166603..0216f7843a 100644 --- a/tests/integration/standard/test_tablets.py +++ b/tests/integration/standard/test_tablets.py @@ -5,15 +5,12 @@ from cassandra.cluster import Cluster from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy -from tests.integration import PROTOCOL_VERSION, use_cluster +from tests.integration import PROTOCOL_VERSION, use_cluster, get_cluster from 
tests.unit.test_host_connection_pool import LOGGER -CCM_CLUSTER = None def setup_module(): - global CCM_CLUSTER - - CCM_CLUSTER = use_cluster('tablets', [3], start=True) + use_cluster('tablets', [3], start=True) class TestTabletsIntegration: @@ -193,7 +190,7 @@ def drop_ks(_): def test_tablets_invalidation_decommission_non_cc_node(self): def decommission_non_cc_node(rec): # Drop and recreate ks and table to trigger tablets invalidation - for node in CCM_CLUSTER.nodes.values(): + for node in get_cluster().nodes.values(): if self.cluster.control_connection._connection.endpoint.address == node.network_interfaces["storage"][0]: # Ignore node that control connection is connected to continue diff --git a/tests/integration/standard/test_types.py b/tests/integration/standard/test_types.py index eb50c7780a..4ee9b70cde 100644 --- a/tests/integration/standard/test_types.py +++ b/tests/integration/standard/test_types.py @@ -36,12 +36,14 @@ from cassandra.query import dict_factory, ordered_dict_factory from cassandra.util import sortedset, Duration, OrderedMap from tests.unit.cython.utils import cythontest +from tests.util import assertEqual from tests.integration import use_singledc, execute_until_pass, notprotocolv1, \ BasicSharedKeyspaceUnitTestCase, greaterthancass21, lessthancass30, \ greaterthanorequalcass3_10, TestCluster, requires_composite_type, greaterthanorequalcass50 from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, COLLECTION_TYPES, PRIMITIVE_DATATYPES_KEYS, \ get_sample, get_all_samples, get_collection_sample +import pytest def setup_module(): @@ -72,7 +74,7 @@ def test_can_insert_blob_type_as_string(self): results = s.execute("SELECT * FROM blobstring").one() for expected, actual in zip(params, results): - self.assertEqual(expected, actual) + assert expected == actual def test_can_insert_blob_type_as_bytearray(self): """ @@ -87,7 +89,7 @@ def test_can_insert_blob_type_as_bytearray(self): results = s.execute("SELECT * FROM 
blobbytes").one() for expected, actual in zip(params, results): - self.assertEqual(expected, actual) + assert expected == actual @unittest.skipIf(not hasattr(cassandra, 'deserializers'), "Cython required for to test DesBytesTypeArray deserializer") def test_des_bytes_type_array(self): @@ -114,7 +116,7 @@ def test_des_bytes_type_array(self): results = s.execute("SELECT * FROM blobbytes2").one() for expected, actual in zip(params, results): - self.assertEqual(expected, actual) + assert expected == actual finally: if original is not None: cassandra.deserializers.DesBytesType=original @@ -149,7 +151,7 @@ def test_can_insert_primitive_datatypes(self): # verify data results = s.execute("SELECT {0} FROM alltypes WHERE zz=0".format(columns_string)).one() for expected, actual in zip(params, results): - self.assertEqual(actual, expected) + assert actual == expected # try the same thing sending one insert at the time s.execute("TRUNCATE alltypes;") @@ -169,7 +171,7 @@ def test_can_insert_primitive_datatypes(self): if isinstance(data_sample, ipaddress.IPv4Address) or isinstance(data_sample, ipaddress.IPv6Address): compare_value = str(data_sample) - self.assertEqual(result, compare_value) + assert result == compare_value # try the same thing with a prepared statement placeholders = ','.join(["?"] * len(col_names)) @@ -180,13 +182,13 @@ def test_can_insert_primitive_datatypes(self): # verify data results = s.execute("SELECT {0} FROM alltypes WHERE zz=0".format(columns_string)).one() for expected, actual in zip(params, results): - self.assertEqual(actual, expected) + assert actual == expected # verify data with prepared statement query select = s.prepare("SELECT {0} FROM alltypes WHERE zz=?".format(columns_string)) results = s.execute(select.bind([0])).one() for expected, actual in zip(params, results): - self.assertEqual(actual, expected) + assert actual == expected # verify data with with prepared statement, use dictionary with no explicit columns select = s.prepare("SELECT * 
FROM alltypes") @@ -194,7 +196,7 @@ def test_can_insert_primitive_datatypes(self): execution_profile=s.execution_profile_clone_update(EXEC_PROFILE_DEFAULT, row_factory=ordered_dict_factory)).one() for expected, actual in zip(params, results.values()): - self.assertEqual(actual, expected) + assert actual == expected c.shutdown() @@ -242,7 +244,7 @@ def test_can_insert_collection_datatypes(self): # verify data results = s.execute("SELECT {0} FROM allcoltypes WHERE zz=0".format(columns_string)).one() for expected, actual in zip(params, results): - self.assertEqual(actual, expected) + assert actual == expected # create the input for prepared statement params = [0] @@ -258,13 +260,13 @@ def test_can_insert_collection_datatypes(self): # verify data results = s.execute("SELECT {0} FROM allcoltypes WHERE zz=0".format(columns_string)).one() for expected, actual in zip(params, results): - self.assertEqual(actual, expected) + assert actual == expected # verify data with prepared statement query select = s.prepare("SELECT {0} FROM allcoltypes WHERE zz=?".format(columns_string)) results = s.execute(select.bind([0])).one() for expected, actual in zip(params, results): - self.assertEqual(actual, expected) + assert actual == expected # verify data with with prepared statement, use dictionary with no explicit columns select = s.prepare("SELECT * FROM allcoltypes") @@ -273,7 +275,7 @@ def test_can_insert_collection_datatypes(self): row_factory=ordered_dict_factory)).one() for expected, actual in zip(params, results.values()): - self.assertEqual(actual, expected) + assert actual == expected c.shutdown() @@ -307,12 +309,12 @@ def test_can_insert_empty_strings_and_nulls(self): columns_string = ','.join(col_names) s.execute("INSERT INTO all_empty (zz) VALUES (2)") results = s.execute("SELECT {0} FROM all_empty WHERE zz=2".format(columns_string)).one() - self.assertTrue(all(x is None for x in results)) + assert all(x is None for x in results) # verify all types initially null with 
prepared statement select = s.prepare("SELECT {0} FROM all_empty WHERE zz=?".format(columns_string)) results = s.execute(select.bind([2])).one() - self.assertTrue(all(x is None for x in results)) + assert all(x is None for x in results) # insert empty strings for string-like fields expected_values = dict((col, '') for col in string_columns) @@ -323,21 +325,21 @@ def test_can_insert_empty_strings_and_nulls(self): # verify string types empty with simple statement results = s.execute("SELECT {0} FROM all_empty WHERE zz=3".format(columns_string)).one() for expected, actual in zip(expected_values.values(), results): - self.assertEqual(actual, expected) + assert actual == expected # verify string types empty with prepared statement results = s.execute(s.prepare("SELECT {0} FROM all_empty WHERE zz=?".format(columns_string)), [3]).one() for expected, actual in zip(expected_values.values(), results): - self.assertEqual(actual, expected) + assert actual == expected # non-string types shouldn't accept empty strings for col in non_string_columns: query = "INSERT INTO all_empty (zz, {0}) VALUES (4, %s)".format(col) - with self.assertRaises(InvalidRequest): + with pytest.raises(InvalidRequest): s.execute(query, ['']) insert = s.prepare("INSERT INTO all_empty (zz, {0}) VALUES (4, ?)".format(col)) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): s.execute(insert, ['']) # verify that Nones can be inserted and overwrites existing data @@ -360,13 +362,13 @@ def test_can_insert_empty_strings_and_nulls(self): query = "SELECT {0} FROM all_empty WHERE zz=5".format(columns_string) results = s.execute(query).one() for col in results: - self.assertEqual(None, col) + assert None == col # check via prepared statement select = s.prepare("SELECT {0} FROM all_empty WHERE zz=?".format(columns_string)) results = s.execute(select.bind([5])).one() for col in results: - self.assertEqual(None, col) + assert None == col # do the same thing again, but use a prepared statement to 
insert the nulls s.execute(simple_insert, params) @@ -377,11 +379,11 @@ def test_can_insert_empty_strings_and_nulls(self): results = s.execute(query).one() for col in results: - self.assertEqual(None, col) + assert None == col results = s.execute(select.bind([5])).one() for col in results: - self.assertEqual(None, col) + assert None == col def test_can_insert_empty_values_for_int32(self): """ @@ -394,7 +396,7 @@ def test_can_insert_empty_values_for_int32(self): try: Int32Type.support_empty_values = True results = execute_until_pass(s, "SELECT b FROM empty_values WHERE a='a'").one() - self.assertIs(EMPTY, results.b) + assert EMPTY is results.b finally: Int32Type.support_empty_values = False @@ -419,13 +421,13 @@ def test_timezone_aware_datetimes_are_timestamps(self): # test non-prepared statement s.execute("INSERT INTO tz_aware (a, b) VALUES ('key1', %s)", [dt]) result = s.execute("SELECT b FROM tz_aware WHERE a='key1'").one().b - self.assertEqual(dt.utctimetuple(), result.utctimetuple()) + assert dt.utctimetuple() == result.utctimetuple() # test prepared statement insert = s.prepare("INSERT INTO tz_aware (a, b) VALUES ('key2', ?)") s.execute(insert.bind([dt])) result = s.execute("SELECT b FROM tz_aware WHERE a='key2'").one().b - self.assertEqual(dt.utctimetuple(), result.utctimetuple()) + assert dt.utctimetuple() == result.utctimetuple() def test_can_insert_tuples(self): """ @@ -447,20 +449,20 @@ def test_can_insert_tuples(self): complete = ('foo', 123, True) s.execute("INSERT INTO tuple_type (a, b) VALUES (0, %s)", parameters=(complete,)) result = s.execute("SELECT b FROM tuple_type WHERE a=0").one() - self.assertEqual(complete, result.b) + assert complete == result.b partial = ('bar', 456) partial_result = partial + (None,) s.execute("INSERT INTO tuple_type (a, b) VALUES (1, %s)", parameters=(partial,)) result = s.execute("SELECT b FROM tuple_type WHERE a=1").one() - self.assertEqual(partial_result, result.b) + assert partial_result == result.b # test single 
value tuples subpartial = ('zoo',) subpartial_result = subpartial + (None, None) s.execute("INSERT INTO tuple_type (a, b) VALUES (2, %s)", parameters=(subpartial,)) result = s.execute("SELECT b FROM tuple_type WHERE a=2").one() - self.assertEqual(subpartial_result, result.b) + assert subpartial_result == result.b # test prepared statement prepared = s.prepare("INSERT INTO tuple_type (a, b) VALUES (?, ?)") @@ -469,12 +471,13 @@ def test_can_insert_tuples(self): s.execute(prepared, parameters=(5, subpartial)) # extra items in the tuple should result in an error - self.assertRaises(ValueError, s.execute, prepared, parameters=(0, (1, 2, 3, 4, 5, 6))) + with pytest.raises(ValueError): + s.execute(prepared, parameters=(0, (1, 2, 3, 4, 5, 6))) prepared = s.prepare("SELECT b FROM tuple_type WHERE a=?") - self.assertEqual(complete, s.execute(prepared, (3,)).one().b) - self.assertEqual(partial_result, s.execute(prepared, (4,)).one().b) - self.assertEqual(subpartial_result, s.execute(prepared, (5,)).one().b) + assert complete == s.execute(prepared, (3,)).one().b + assert partial_result == s.execute(prepared, (4,)).one().b + assert subpartial_result == s.execute(prepared, (5,)).one().b c.shutdown() @@ -507,7 +510,8 @@ def test_can_insert_tuples_with_varying_lengths(self): for i in lengths: # ensure tuples of larger sizes throw an error created_tuple = tuple(range(0, i + 1)) - self.assertRaises(InvalidRequest, s.execute, "INSERT INTO tuple_lengths (k, v_%s) VALUES (0, %s)", (i, created_tuple)) + with pytest.raises(InvalidRequest): + s.execute("INSERT INTO tuple_lengths (k, v_%s) VALUES (0, %s)", (i, created_tuple)) # ensure tuples of proper sizes are written and read correctly created_tuple = tuple(range(0, i)) @@ -515,7 +519,7 @@ def test_can_insert_tuples_with_varying_lengths(self): s.execute("INSERT INTO tuple_lengths (k, v_%s) VALUES (0, %s)", (i, created_tuple)) result = s.execute("SELECT v_%s FROM tuple_lengths WHERE k=0", (i,)).one() - 
self.assertEqual(tuple(created_tuple), result['v_%s' % i]) + assert tuple(created_tuple) == result['v_%s' % i] c.shutdown() def test_can_insert_tuples_all_primitive_datatypes(self): @@ -543,7 +547,7 @@ def test_can_insert_tuples_all_primitive_datatypes(self): expected = tuple(values + [None] * (type_count - len(values))) s.execute("INSERT INTO tuple_primitive (k, v) VALUES (%s, %s)", (i, tuple(values))) result = s.execute("SELECT v FROM tuple_primitive WHERE k=%s", (i,)).one() - self.assertEqual(result.v, expected) + assert result.v == expected c.shutdown() def test_can_insert_tuples_all_collection_datatypes(self): @@ -598,7 +602,7 @@ def test_can_insert_tuples_all_collection_datatypes(self): s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple)) result = s.execute("SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,)).one() - self.assertEqual(created_tuple, result['v_%s' % i]) + assert created_tuple == result['v_%s' % i] i += 1 # test tuple> @@ -607,7 +611,7 @@ def test_can_insert_tuples_all_collection_datatypes(self): s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple)) result = s.execute("SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,)).one() - self.assertEqual(created_tuple, result['v_%s' % i]) + assert created_tuple == result['v_%s' % i] i += 1 # test tuple> @@ -621,7 +625,7 @@ def test_can_insert_tuples_all_collection_datatypes(self): s.execute("INSERT INTO tuple_non_primative (k, v_%s) VALUES (0, %s)", (i, created_tuple)) result = s.execute("SELECT v_%s FROM tuple_non_primative WHERE k=0", (i,)).one() - self.assertEqual(created_tuple, result['v_%s' % i]) + assert created_tuple == result['v_%s' % i] i += 1 c.shutdown() @@ -682,7 +686,7 @@ def test_can_insert_nested_tuples(self): # verify tuple was written and read correctly result = s.execute("SELECT v_%s FROM nested_tuples WHERE k=%s", (i, i)).one() - self.assertEqual(created_tuple, result['v_%s' % i]) + assert created_tuple == 
result['v_%s' % i] c.shutdown() def test_can_insert_tuples_with_nulls(self): @@ -701,16 +705,16 @@ def test_can_insert_tuples_with_nulls(self): s.execute(insert, [(None, None, None, None)]) result = s.execute("SELECT * FROM tuples_nulls WHERE k=0") - self.assertEqual((None, None, None, None), result.one().t) + assert (None, None, None, None) == result.one().t read = s.prepare("SELECT * FROM tuples_nulls WHERE k=0") - self.assertEqual((None, None, None, None), s.execute(read).one().t) + assert (None, None, None, None) == s.execute(read).one().t # also test empty strings where compatible s.execute(insert, [('', None, None, b'')]) result = s.execute("SELECT * FROM tuples_nulls WHERE k=0") - self.assertEqual(('', None, None, b''), result.one().t) - self.assertEqual(('', None, None, b''), s.execute(read).one().t) + assert ('', None, None, b'') == result.one().t + assert ('', None, None, b'') == s.execute(read).one().t def test_insert_collection_with_null_fails(self): """ @@ -729,9 +733,11 @@ def test_insert_collection_with_null_fails(self): s.execute(f'CREATE TABLE collection_nulls (k int PRIMARY KEY, {", ".join(columns)})') def raises_simple_and_prepared(exc_type, query_str, args): - self.assertRaises(exc_type, lambda: s.execute(query_str, args)) + with pytest.raises(exc_type): + s.execute(query_str, args) p = s.prepare(query_str.replace('%s', '?')) - self.assertRaises(exc_type, lambda: s.execute(p, args)) + with pytest.raises(exc_type): + s.execute(p, args) i = 0 for simple_type in PRIMITIVE_DATATYPES_KEYS: @@ -781,14 +787,14 @@ def test_can_read_composite_type(self): # CompositeType string literals are split on ':' chars s.execute("INSERT INTO composites (a, b) VALUES (0, 'abc:123')") result = s.execute("SELECT * FROM composites WHERE a = 0").one() - self.assertEqual(0, result.a) - self.assertEqual(('abc', 123), result.b) + assert 0 == result.a + assert ('abc', 123) == result.b # CompositeType values can omit elements at the end s.execute("INSERT INTO composites (a, 
b) VALUES (0, 'abc')") result = s.execute("SELECT * FROM composites WHERE a = 0").one() - self.assertEqual(0, result.a) - self.assertEqual(('abc',), result.b) + assert 0 == result.a + assert ('abc',) == result.b @notprotocolv1 def test_special_float_cql_encoding(self): @@ -815,11 +821,11 @@ def verify_insert_select(ins_statement, sel_statement): for f in items: row = s.execute(sel_statement, (f,)).one() if math.isnan(f): - self.assertTrue(math.isnan(row.f)) - self.assertTrue(math.isnan(row.d)) + assert math.isnan(row.f) + assert math.isnan(row.d) else: - self.assertEqual(row.f, f) - self.assertEqual(row.d, f) + assert row.f == f + assert row.d == f # cql encoding verify_insert_select('INSERT INTO float_cql_encoding (f, d) VALUES (%s, %s)', @@ -847,7 +853,7 @@ def test_cython_decimal(self): try: self.session.execute("INSERT INTO {0} (dc) VALUES (-1.08430792318105707)".format(self.function_table_name)) results = self.session.execute("SELECT * FROM {0}".format(self.function_table_name)) - self.assertTrue(str(results.one().dc) == '-1.08430792318105707') + assert str(results.one().dc) == '-1.08430792318105707' finally: self.session.execute("DROP TABLE {0}".format(self.function_table_name)) @@ -891,14 +897,16 @@ def test_smoke_duration_values(self): results = self.session.execute("SELECT * FROM duration_smoke") v = results.one()[1] - self.assertEqual(Duration(month_day_value, month_day_value, nanosecond_value), v, - "Error encoding value {0},{0},{1}".format(month_day_value, nanosecond_value)) + assert Duration(month_day_value, month_day_value, nanosecond_value) == v, "Error encoding value {0},{0},{1}".format(month_day_value, nanosecond_value) - self.assertRaises(ValueError, self.session.execute, prepared, + with pytest.raises(ValueError): + self.session.execute(prepared, (1, Duration(0, 0, int("8FFFFFFFFFFFFFF0", 16)))) - self.assertRaises(ValueError, self.session.execute, prepared, + with pytest.raises(ValueError): + self.session.execute(prepared, (1, Duration(0, 
int("8FFFFFFFFFFFFFF0", 16), 0))) - self.assertRaises(ValueError, self.session.execute, prepared, + with pytest.raises(ValueError): + self.session.execute(prepared, (1, Duration(int("8FFFFFFFFFFFFFF0", 16), 0, 0))) class TypeTestsProtocol(BasicSharedKeyspaceUnitTestCase): @@ -948,16 +956,16 @@ def read_inserts_at_level(self, proto_ver): session = TestCluster(protocol_version=proto_ver).connect(self.keyspace_name) try: results = session.execute('select * from t').one() - self.assertEqual("[SortedSet([1, 2]), SortedSet([3, 5])]", str(results.v)) + assert "[SortedSet([1, 2]), SortedSet([3, 5])]" == str(results.v) results = session.execute('select * from u').one() - self.assertEqual("SortedSet([[1, 2], [3, 5]])", str(results.v)) + assert "SortedSet([[1, 2], [3, 5]])" == str(results.v) results = session.execute('select * from v').one() - self.assertEqual("{SortedSet([1, 2]): [1, 2, 3], SortedSet([3, 5]): [4, 5, 6]}", str(results.v)) + assert "{SortedSet([1, 2]): [1, 2, 3], SortedSet([3, 5]): [4, 5, 6]}" == str(results.v) results = session.execute('select * from w').one() - self.assertEqual("typ(v0=OrderedMapSerializedKey([(1, [1, 2, 3]), (2, [4, 5, 6])]), v1=[7, 8, 9])", str(results.v)) + assert "typ(v0=OrderedMapSerializedKey([(1, [1, 2, 3]), (2, [4, 5, 6])]), v1=[7, 8, 9])" == str(results.v) finally: session.cluster.shutdown() @@ -985,7 +993,7 @@ class TypeTestsVector(BasicSharedKeyspaceUnitTestCase): def _get_first_j(self, rs): rows = rs.all() - self.assertEqual(len(rows), 1) + assert len(rows) == 1 return rows[0].j def _get_row_simple(self, idx, table_name): @@ -1037,14 +1045,14 @@ def random_subtype_vector(): test_fn(observed2[idx], expected2[idx]) def test_round_trip_integers(self): - self._round_trip_test("int", partial(random.randint, 0, 2 ** 31), self.assertEqual) - self._round_trip_test("bigint", partial(random.randint, 0, 2 ** 63), self.assertEqual) - self._round_trip_test("smallint", partial(random.randint, 0, 2 ** 15), self.assertEqual) - 
self._round_trip_test("tinyint", partial(random.randint, 0, (2 ** 7) - 1), self.assertEqual) - self._round_trip_test("varint", partial(random.randint, 0, 2 ** 63), self.assertEqual) + self._round_trip_test("int", partial(random.randint, 0, 2 ** 31), assertEqual) + self._round_trip_test("bigint", partial(random.randint, 0, 2 ** 63), assertEqual) + self._round_trip_test("smallint", partial(random.randint, 0, 2 ** 15), assertEqual) + self._round_trip_test("tinyint", partial(random.randint, 0, (2 ** 7) - 1), assertEqual) + self._round_trip_test("varint", partial(random.randint, 0, 2 ** 63), assertEqual) def test_round_trip_floating_point(self): - _almost_equal_test_fn = partial(self.assertAlmostEqual, places=5) + _almost_equal_test_fn = partial(pytest.approx, abs=1e-5) def _random_decimal(): return Decimal(random.uniform(0.0, 100.0)) @@ -1058,11 +1066,11 @@ def test_round_trip_text(self): def _random_string(): return ''.join(random.choice(string.ascii_uppercase + string.digits) for _ in range(24)) - self._round_trip_test("ascii", _random_string, self.assertEqual) - self._round_trip_test("text", _random_string, self.assertEqual) + self._round_trip_test("ascii", _random_string, assertEqual) + self._round_trip_test("text", _random_string, assertEqual) def test_round_trip_date_and_time(self): - _almost_equal_test_fn = partial(self.assertAlmostEqual, delta=timedelta(seconds=1)) + _almost_equal_test_fn = partial(pytest.approx, abs=timedelta(seconds=1)) def _random_datetime(): return datetime.today() - timedelta(hours=random.randint(0,18), days=random.randint(1,1000)) def _random_date(): @@ -1070,13 +1078,13 @@ def _random_date(): def _random_time(): return _random_datetime().time() - self._round_trip_test("date", _random_date, self.assertEqual) - self._round_trip_test("time", _random_time, self.assertEqual) + self._round_trip_test("date", _random_date, assertEqual) + self._round_trip_test("time", _random_time, assertEqual) self._round_trip_test("timestamp", _random_datetime, 
_almost_equal_test_fn) def test_round_trip_uuid(self): - self._round_trip_test("uuid", uuid.uuid1, self.assertEqual) - self._round_trip_test("timeuuid", uuid.uuid1, self.assertEqual) + self._round_trip_test("uuid", uuid.uuid1, assertEqual) + self._round_trip_test("timeuuid", uuid.uuid1, assertEqual) def test_round_trip_miscellany(self): def _random_bytes(): @@ -1088,10 +1096,10 @@ def _random_duration(): def _random_inet(): return socket.inet_ntoa(_random_bytes()) - self._round_trip_test("boolean", _random_boolean, self.assertEqual) - self._round_trip_test("duration", _random_duration, self.assertEqual) - self._round_trip_test("inet", _random_inet, self.assertEqual) - self._round_trip_test("blob", _random_bytes, self.assertEqual) + self._round_trip_test("boolean", _random_boolean, assertEqual) + self._round_trip_test("duration", _random_duration, assertEqual) + self._round_trip_test("inet", _random_inet, assertEqual) + self._round_trip_test("blob", _random_bytes, assertEqual) def test_round_trip_collections(self): def _random_seq(): @@ -1102,21 +1110,21 @@ def _random_map(): return {k:v for (k,v) in zip(_random_seq(), _random_seq())} # Goal here is to test collections of both fixed and variable size subtypes - self._round_trip_test("list", _random_seq, self.assertEqual) - self._round_trip_test("list", _random_seq, self.assertEqual) - self._round_trip_test("set", _random_set, self.assertEqual) - self._round_trip_test("set", _random_set, self.assertEqual) - self._round_trip_test("map", _random_map, self.assertEqual) - self._round_trip_test("map", _random_map, self.assertEqual) - self._round_trip_test("map", _random_map, self.assertEqual) - self._round_trip_test("map", _random_map, self.assertEqual) + self._round_trip_test("list", _random_seq, assertEqual) + self._round_trip_test("list", _random_seq, assertEqual) + self._round_trip_test("set", _random_set, assertEqual) + self._round_trip_test("set", _random_set, assertEqual) + self._round_trip_test("map", _random_map, 
assertEqual) + self._round_trip_test("map", _random_map, assertEqual) + self._round_trip_test("map", _random_map, assertEqual) + self._round_trip_test("map", _random_map, assertEqual) def test_round_trip_vector_of_vectors(self): def _random_vector(): return [random.randint(0,100000) for _ in range(2)] - self._round_trip_test("vector", _random_vector, self.assertEqual) - self._round_trip_test("vector", _random_vector, self.assertEqual) + self._round_trip_test("vector", _random_vector, assertEqual) + self._round_trip_test("vector", _random_vector, assertEqual) def test_round_trip_tuples(self): def _random_tuple(): @@ -1124,15 +1132,15 @@ def _random_tuple(): # Unfortunately we can't use positional parameters when inserting tuples because the driver will try to encode # them as lists before sending them to the server... and that confuses the parsing logic. - self._round_trip_test("tuple", _random_tuple, self.assertEqual, use_positional_parameters=False) - self._round_trip_test("tuple", _random_tuple, self.assertEqual, use_positional_parameters=False) - self._round_trip_test("tuple", _random_tuple, self.assertEqual, use_positional_parameters=False) - self._round_trip_test("tuple", _random_tuple, self.assertEqual, use_positional_parameters=False) + self._round_trip_test("tuple", _random_tuple, assertEqual, use_positional_parameters=False) + self._round_trip_test("tuple", _random_tuple, assertEqual, use_positional_parameters=False) + self._round_trip_test("tuple", _random_tuple, assertEqual, use_positional_parameters=False) + self._round_trip_test("tuple", _random_tuple, assertEqual, use_positional_parameters=False) def test_round_trip_udts(self): def _udt_equal_test_fn(udt1, udt2): - self.assertEqual(udt1.a, udt2.a) - self.assertEqual(udt1.b, udt2.b) + assert udt1.a == udt2.a + assert udt1.b == udt2.b self.session.execute("create type {}.fixed_type (a int, b int)".format(self.keyspace_name)) self.session.execute("create type {}.mixed_type_one (a int, b 
varint)".format(self.keyspace_name)) diff --git a/tests/integration/standard/test_udts.py b/tests/integration/standard/test_udts.py index 7188bf3eb8..dd696ea0e9 100644 --- a/tests/integration/standard/test_udts.py +++ b/tests/integration/standard/test_udts.py @@ -25,6 +25,7 @@ BasicSegregatedKeyspaceUnitTestCase, greaterthancass20, lessthancass30, greaterthanorequalcass36, TestCluster from tests.integration.datatype_utils import update_datatypes, PRIMITIVE_DATATYPES, PRIMITIVE_DATATYPES_KEYS, \ COLLECTION_TYPES, get_sample, get_collection_sample +import pytest nested_collection_udt = namedtuple('nested_collection_udt', ['m', 't', 'l', 's']) nested_collection_udt_nested = namedtuple('nested_collection_udt_nested', ['m', 't', 'l', 's', 'u']) @@ -65,9 +66,9 @@ def test_non_frozen_udts(self): self.session.execute("INSERT INTO {0} (a, b) VALUES (%s, %s)".format(self.function_table_name), (0, User("Nebraska", True))) self.session.execute("UPDATE {0} SET b.has_corn = False where a = 0".format(self.function_table_name)) result = self.session.execute("SELECT * FROM {0}".format(self.function_table_name)) - self.assertFalse(result.one().b.has_corn) + assert not result.one().b.has_corn table_sql = self.cluster.metadata.keyspaces[self.keyspace_name].tables[self.function_table_name].as_cql_query() - self.assertNotIn("", table_sql) + assert "" not in table_sql def test_can_insert_unprepared_registered_udts(self): """ @@ -86,9 +87,9 @@ def test_can_insert_unprepared_registered_udts(self): s.execute("INSERT INTO mytable (a, b) VALUES (%s, %s)", (0, User(42, 'bob'))) result = s.execute("SELECT b FROM mytable WHERE a=0") row = result.one() - self.assertEqual(42, row.b.age) - self.assertEqual('bob', row.b.name) - self.assertTrue(type(row.b) is User) + assert 42 == row.b.age + assert 'bob' == row.b.name + assert type(row.b) is User # use the same UDT name in a different keyspace s.execute(""" @@ -105,9 +106,9 @@ def test_can_insert_unprepared_registered_udts(self): s.execute("INSERT 
INTO mytable (a, b) VALUES (%s, %s)", (0, User('Texas', True))) result = s.execute("SELECT b FROM mytable WHERE a=0") row = result.one() - self.assertEqual('Texas', row.b.state) - self.assertEqual(True, row.b.is_cool) - self.assertTrue(type(row.b) is User) + assert 'Texas' == row.b.state + assert True == row.b.is_cool + assert type(row.b) is User s.execute("DROP KEYSPACE udt_test_unprepared_registered2") @@ -151,17 +152,17 @@ def test_can_register_udt_before_connecting(self): s.execute("INSERT INTO udt_test_register_before_connecting.mytable (a, b) VALUES (%s, %s)", (0, User1(42, 'bob'))) result = s.execute("SELECT b FROM udt_test_register_before_connecting.mytable WHERE a=0") row = result.one() - self.assertEqual(42, row.b.age) - self.assertEqual('bob', row.b.name) - self.assertTrue(type(row.b) is User1) + assert 42 == row.b.age + assert 'bob' == row.b.name + assert type(row.b) is User1 # use the same UDT name in a different keyspace s.execute("INSERT INTO udt_test_register_before_connecting2.mytable (a, b) VALUES (%s, %s)", (0, User2('Texas', True))) result = s.execute("SELECT b FROM udt_test_register_before_connecting2.mytable WHERE a=0") row = result.one() - self.assertEqual('Texas', row.b.state) - self.assertEqual(True, row.b.is_cool) - self.assertTrue(type(row.b) is User2) + assert 'Texas' == row.b.state + assert True == row.b.is_cool + assert type(row.b) is User2 s.execute("DROP KEYSPACE udt_test_register_before_connecting") s.execute("DROP KEYSPACE udt_test_register_before_connecting2") @@ -186,8 +187,8 @@ def test_can_insert_prepared_unregistered_udts(self): select = s.prepare("SELECT b FROM mytable WHERE a=?") result = s.execute(select, (0,)) row = result.one() - self.assertEqual(42, row.b.age) - self.assertEqual('bob', row.b.name) + assert 42 == row.b.age + assert 'bob' == row.b.name # use the same UDT name in a different keyspace s.execute(""" @@ -205,8 +206,8 @@ def test_can_insert_prepared_unregistered_udts(self): select = s.prepare("SELECT b FROM 
mytable WHERE a=?") result = s.execute(select, (0,)) row = result.one() - self.assertEqual('Texas', row.b.state) - self.assertEqual(True, row.b.is_cool) + assert 'Texas' == row.b.state + assert True == row.b.is_cool s.execute("DROP KEYSPACE udt_test_prepared_unregistered2") @@ -232,9 +233,9 @@ def test_can_insert_prepared_registered_udts(self): select = s.prepare("SELECT b FROM mytable WHERE a=?") result = s.execute(select, (0,)) row = result.one() - self.assertEqual(42, row.b.age) - self.assertEqual('bob', row.b.name) - self.assertTrue(type(row.b) is User) + assert 42 == row.b.age + assert 'bob' == row.b.name + assert type(row.b) is User # use the same UDT name in a different keyspace s.execute(""" @@ -254,9 +255,9 @@ def test_can_insert_prepared_registered_udts(self): select = s.prepare("SELECT b FROM mytable WHERE a=?") result = s.execute(select, (0,)) row = result.one() - self.assertEqual('Texas', row.b.state) - self.assertEqual(True, row.b.is_cool) - self.assertTrue(type(row.b) is User) + assert 'Texas' == row.b.state + assert True == row.b.is_cool + assert type(row.b) is User s.execute("DROP KEYSPACE udt_test_prepared_registered2") @@ -280,15 +281,15 @@ def test_can_insert_udts_with_nulls(self): s.execute(insert, [User(None, None, None, None)]) results = s.execute("SELECT b FROM mytable WHERE a=0") - self.assertEqual((None, None, None, None), results.one().b) + assert (None, None, None, None) == results.one().b select = s.prepare("SELECT b FROM mytable WHERE a=0") - self.assertEqual((None, None, None, None), s.execute(select).one().b) + assert (None, None, None, None) == s.execute(select).one().b # also test empty strings s.execute(insert, [User('', None, None, bytes())]) results = s.execute("SELECT b FROM mytable WHERE a=0") - self.assertEqual(('', None, None, bytes()), results.one().b) + assert ('', None, None, bytes()) == results.one().b c.shutdown() @@ -328,7 +329,7 @@ def test_can_insert_udts_with_varying_lengths(self): # verify udt was written and read 
correctly, increase timeout to avoid the query failure on slow systems result = s.execute("SELECT v FROM mytable WHERE k=0").one() - self.assertEqual(created_udt, result.v) + assert created_udt == result.v c.shutdown() @@ -366,7 +367,7 @@ def nested_udt_verification_helper(self, session, max_nesting_depth, udts): # verify udt was written and read correctly result = session.execute("SELECT v_{0} FROM mytable WHERE k=0".format(i)).one() - self.assertEqual(udt, result["v_{0}".format(i)]) + assert udt == result["v_{0}".format(i)] # write udt via prepared statement insert = session.prepare("INSERT INTO mytable (k, v_{0}) VALUES (1, ?)".format(i)) @@ -374,7 +375,7 @@ def nested_udt_verification_helper(self, session, max_nesting_depth, udts): # verify udt was written and read correctly result = session.execute("SELECT v_{0} FROM mytable WHERE k=1".format(i)).one() - self.assertEqual(udt, result["v_{0}".format(i)]) + assert udt == result["v_{0}".format(i)] def _cluster_default_dict_factory(self): return TestCluster( @@ -442,7 +443,7 @@ def test_can_insert_nested_unregistered_udts(self): # verify udt was written and read correctly result = s.execute("SELECT v_{0} FROM mytable WHERE k=0".format(i)).one() - self.assertEqual(udt, result["v_{0}".format(i)]) + assert udt == result["v_{0}".format(i)] def test_can_insert_nested_registered_udts_with_different_namedtuples(self): """ @@ -482,13 +483,13 @@ def test_raise_error_on_nonexisting_udts(self): s = c.connect(self.keyspace_name, wait_for_all_pools=True) User = namedtuple('user', ('age', 'name')) - with self.assertRaises(UserTypeDoesNotExist): + with pytest.raises(UserTypeDoesNotExist): c.register_user_type("some_bad_keyspace", "user", User) - with self.assertRaises(UserTypeDoesNotExist): + with pytest.raises(UserTypeDoesNotExist): c.register_user_type("system", "user", User) - with self.assertRaises(InvalidRequest): + with pytest.raises(InvalidRequest): s.execute("CREATE TABLE mytable (a int PRIMARY KEY, b frozen)") 
c.shutdown() @@ -534,7 +535,7 @@ def test_can_insert_udt_all_datatypes(self): row = results.one().b for expected, actual in zip(params, row): - self.assertEqual(expected, actual) + assert expected == actual c.shutdown() @@ -592,7 +593,7 @@ def test_can_insert_udt_all_collection_datatypes(self): row = results.one().b for expected, actual in zip(params, row): - self.assertEqual(expected, actual) + assert expected == actual c.shutdown() @@ -600,7 +601,7 @@ def insert_select_column(self, session, table_name, column_name, value): insert = session.prepare("INSERT INTO %s (k, %s) VALUES (?, ?)" % (table_name, column_name)) session.execute(insert, (0, value)) result = session.execute("SELECT %s FROM %s WHERE k=%%s" % (column_name, table_name), (0,)).one()[0] - self.assertEqual(result, value) + assert result == value def test_can_insert_nested_collections(self): """ @@ -669,15 +670,15 @@ def test_non_alphanum_identifiers(self): row = s.execute('SELECT * FROM %s' % (self.table_name,)).one() k, v = row.non_alphanum_type_map.popitem() - self.assertEqual(v, 0) - self.assertEqual(k.__class__, tuple) - self.assertEqual(k[0], 'nonalphanum') + assert v == 0 + assert k.__class__ == tuple + assert k[0] == 'nonalphanum' k, v = row.alphanum_type_map.popitem() - self.assertEqual(v, 1) - self.assertNotEqual(k.__class__, tuple) # should be the namedtuple type - self.assertEqual(k[0], 'alphanum') - self.assertEqual(k.field_0_, 'alphanum') # named tuple with positional field name + assert v == 1 + assert k.__class__ != tuple # should be the namedtuple type + assert k[0] == 'alphanum' + assert k.field_0_ == 'alphanum' # named tuple with positional field name @lessthancass30 def test_type_alteration(self): @@ -686,9 +687,9 @@ def test_type_alteration(self): """ s = self.session type_name = "type_name" - self.assertNotIn(type_name, s.cluster.metadata.keyspaces['udttests'].user_types) + assert type_name not in s.cluster.metadata.keyspaces['udttests'].user_types s.execute('CREATE TYPE %s (v0 
int)' % (type_name,)) - self.assertIn(type_name, s.cluster.metadata.keyspaces['udttests'].user_types) + assert type_name in s.cluster.metadata.keyspaces['udttests'].user_types s.execute('CREATE TABLE %s (k int PRIMARY KEY, v frozen<%s>)' % (self.table_name, type_name)) s.execute('INSERT INTO %s (k, v) VALUES (0, {v0 : 1})' % (self.table_name,)) @@ -696,24 +697,24 @@ def test_type_alteration(self): s.cluster.register_user_type('udttests', type_name, dict) val = s.execute('SELECT v FROM %s' % self.table_name).one()[0] - self.assertEqual(val['v0'], 1) + assert val['v0'] == 1 # add field s.execute('ALTER TYPE %s ADD v1 text' % (type_name,)) val = s.execute('SELECT v FROM %s' % self.table_name).one()[0] - self.assertEqual(val['v0'], 1) - self.assertIsNone(val['v1']) + assert val['v0'] == 1 + assert val['v1'] is None s.execute("INSERT INTO %s (k, v) VALUES (0, {v0 : 2, v1 : 'sometext'})" % (self.table_name,)) val = s.execute('SELECT v FROM %s' % self.table_name).one()[0] - self.assertEqual(val['v0'], 2) - self.assertEqual(val['v1'], 'sometext') + assert val['v0'] == 2 + assert val['v1'] == 'sometext' # alter field type s.execute('ALTER TYPE %s ALTER v1 TYPE blob' % (type_name,)) s.execute("INSERT INTO %s (k, v) VALUES (0, {v0 : 3, v1 : 0xdeadbeef})" % (self.table_name,)) val = s.execute('SELECT v FROM %s' % self.table_name).one()[0] - self.assertEqual(val['v0'], 3) - self.assertEqual(val['v1'], b'\xde\xad\xbe\xef') + assert val['v0'] == 3 + assert val['v1'] == b'\xde\xad\xbe\xef' @lessthancass30 def test_alter_udt(self): @@ -736,8 +737,8 @@ def test_alter_udt(self): self.session.execute(insert_statement, [1, typetoalter(1)]) results = self.session.execute("SELECT * from {0}".format(self.function_table_name)) for result in results: - self.assertTrue(hasattr(result.typetoalter, 'a')) - self.assertFalse(hasattr(result.typetoalter, 'b')) + assert hasattr(result.typetoalter, 'a') + assert not hasattr(result.typetoalter, 'b') # Alter UDT and ensure the alter is honored in 
results self.session.execute("ALTER TYPE typetoalter add b int") @@ -745,5 +746,5 @@ def test_alter_udt(self): self.session.execute(insert_statement, [2, typetoalter(2, 2)]) results = self.session.execute("SELECT * from {0}".format(self.function_table_name)) for result in results: - self.assertTrue(hasattr(result.typetoalter, 'a')) - self.assertTrue(hasattr(result.typetoalter, 'b')) + assert hasattr(result.typetoalter, 'a') + assert hasattr(result.typetoalter, 'b') diff --git a/tests/integration/upgrade/__init__.py b/tests/integration/upgrade/__init__.py index c5c06c4b01..a1c751bcbd 100644 --- a/tests/integration/upgrade/__init__.py +++ b/tests/integration/upgrade/__init__.py @@ -26,7 +26,7 @@ from ccmlib.node import TimeoutError import time import logging - +import pytest import unittest @@ -78,7 +78,7 @@ def setUpClass(cls): cls.logger_handler = MockLoggingHandler() cls.logger = logging.getLogger(cluster.__name__) cls.logger.addHandler(cls.logger_handler) - + @classmethod def tearDownClass(cls): cls.logger.removeHandler(cls.logger_handler) @@ -166,7 +166,7 @@ def upgrade_node(self, node): try: node.start(wait_for_binary_proto=True, wait_other_notice=True) except TimeoutError: - self.fail("Error starting C* node while upgrading") + pytest.fail("Error starting C* node while upgrading") return True diff --git a/tests/integration/upgrade/test_upgrade.py b/tests/integration/upgrade/test_upgrade.py index 25d14427f2..fec9a38604 100644 --- a/tests/integration/upgrade/test_upgrade.py +++ b/tests/integration/upgrade/test_upgrade.py @@ -21,6 +21,7 @@ from tests.integration.upgrade import UpgradeBase, UpgradeBaseAuth, UpgradePath, upgrade_paths import unittest +import pytest # Previous Cassandra upgrade @@ -57,9 +58,9 @@ def test_can_write(self): time.sleep(0.0001) total_number_of_inserted = self.session.execute("SELECT COUNT(*) from test3rf.test", execution_profile="all").one()[0] - self.assertEqual(total_number_of_inserted, next(c)) + assert total_number_of_inserted == 
next(c) - self.assertEqual(self.logger_handler.get_message_count("error", ""), 0) + assert self.logger_handler.get_message_count("error", "") == 0 @two_to_three_path def test_can_connect(self): @@ -79,10 +80,10 @@ def connect_and_shutdown(): queried_hosts = set() for _ in range(10): results = session.execute("SELECT * from system.local WHERE key='local'") - self.assertGreater(len(results.current_rows), 0) - self.assertEqual(len(results.response_future.attempted_hosts), 1) + assert len(results.current_rows) > 0 + assert len(results.response_future.attempted_hosts) == 1 queried_hosts.add(results.response_future.attempted_hosts[0]) - self.assertEqual(len(queried_hosts), 3) + assert len(queried_hosts) == 3 cluster.shutdown() connect_and_shutdown() @@ -116,9 +117,9 @@ def test_can_write(self): time.sleep(0.0001) total_number_of_inserted = self.session.execute("SELECT COUNT(*) from test3rf.test", execution_profile="all").one()[0] - self.assertEqual(total_number_of_inserted, next(c)) + assert total_number_of_inserted == next(c) - self.assertEqual(self.logger_handler.get_message_count("error", ""), 0) + assert self.logger_handler.get_message_count("error", "") == 0 @two_to_three_path def test_schema_metadata_gets_refreshed(self): @@ -143,14 +144,14 @@ def test_schema_metadata_gets_refreshed(self): # Wait for the control connection to reconnect time.sleep(20) - with self.assertRaises(DriverException): + with pytest.raises(DriverException): self.cluster_driver.refresh_schema_metadata(max_schema_agreement_wait=10) self.upgrade_node(nodes[0]) # Wait for the control connection to reconnect time.sleep(20) self.cluster_driver.refresh_schema_metadata(max_schema_agreement_wait=40) - self.assertNotEqual(original_meta, self.cluster_driver.metadata.keyspaces) + assert original_meta != self.cluster_driver.metadata.keyspaces @two_to_three_path def test_schema_nodes_gets_refreshed(self): @@ -176,10 +177,10 @@ def test_schema_nodes_gets_refreshed(self): 
self._assert_same_token_map(token_map, self.cluster_driver.metadata.token_map) def _assert_same_token_map(self, original, new): - self.assertIsNot(original, new) - self.assertEqual(original.tokens_to_hosts_by_ks, new.tokens_to_hosts_by_ks) - self.assertEqual(original.token_to_host_owner, new.token_to_host_owner) - self.assertEqual(original.ring, new.ring) + assert original is not new + assert original.tokens_to_hosts_by_ks == new.tokens_to_hosts_by_ks + assert original.token_to_host_owner == new.token_to_host_owner + assert original.ring == new.ring two_to_three_with_auth_path = upgrade_paths([ @@ -243,10 +244,10 @@ def connect_and_shutdown(self, auth_provider): queried_hosts = set() for _ in range(10): results = session.execute("SELECT * from system.local WHERE key='local'") - self.assertGreater(len(results.current_rows), 0) - self.assertEqual(len(results.response_future.attempted_hosts), 1) + assert len(results.current_rows) > 0 + assert len(results.response_future.attempted_hosts) == 1 queried_hosts.add(results.response_future.attempted_hosts[0]) - self.assertEqual(len(queried_hosts), 3) + assert len(queried_hosts) == 3 cluster.shutdown() @@ -280,6 +281,6 @@ def test_can_write_speculative(self): time.sleep(0.0001) total_number_of_inserted = session.execute("SELECT COUNT(*) from test3rf.test", execution_profile="all").one()[0] - self.assertEqual(total_number_of_inserted, next(c)) + assert total_number_of_inserted == next(c) - self.assertEqual(self.logger_handler.get_message_count("error", ""), 0) + assert self.logger_handler.get_message_count("error", "") == 0 diff --git a/tests/integration/util.py b/tests/integration/util.py index bcc4cb829b..7cbdfdb22d 100644 --- a/tests/integration/util.py +++ b/tests/integration/util.py @@ -18,14 +18,14 @@ import time -def assert_quiescent_pool_state(test_case, cluster, wait=None): +def assert_quiescent_pool_state(cluster, wait=None): """ Checking the quiescent pool state checks that none of the requests ids have been lost. 
However, the callback corresponding to a request_id is called before the request_id is returned back to the pool, therefore session.execute("SELECT * from system.local") - assert_quiescent_pool_state(self, session.cluster) + assert_quiescent_pool_state(session.cluster) (with no wait) might fail because when execute comes back the request_id hasn't yet been returned to the pool, therefore the wait. @@ -35,23 +35,23 @@ def assert_quiescent_pool_state(test_case, cluster, wait=None): for session in cluster.sessions: pool_states = session.get_pool_state().values() - test_case.assertTrue(pool_states) + assert pool_states for state in pool_states: - test_case.assertFalse(state['shutdown']) - test_case.assertGreater(state['open_count'], 0) + assert not state['shutdown'] + assert state['open_count'] > 0 no_in_flight = all((i == 0 for i in state['in_flights'])) orphans_and_inflights = zip(state['orphan_requests'],state['in_flights']) all_orphaned = all((len(orphans) == inflight for (orphans,inflight) in orphans_and_inflights)) - test_case.assertTrue(no_in_flight or all_orphaned) + assert no_in_flight or all_orphaned for holder in cluster.get_connection_holders(): for connection in holder.get_connections(): # all ids are unique req_ids = connection.request_ids orphan_ids = connection.orphaned_request_ids - test_case.assertEqual(len(req_ids), len(set(req_ids))) - test_case.assertEqual(connection.highest_request_id, len(req_ids) + len(orphan_ids) - 1) - test_case.assertEqual(connection.highest_request_id, max(chain(req_ids, orphan_ids))) + assert len(req_ids) == len(set(req_ids)) + assert connection.highest_request_id == len(req_ids) + len(orphan_ids) - 1 + assert connection.highest_request_id == max(chain(req_ids, orphan_ids)) if PROTOCOL_VERSION < 3: - test_case.assertEqual(connection.highest_request_id, connection.max_request_id) + assert connection.highest_request_id == connection.max_request_id diff --git a/tests/stress_tests/test_multi_inserts.py 
b/tests/stress_tests/test_multi_inserts.py index 84dfc5e6f7..1a5a596aec 100644 --- a/tests/stress_tests/test_multi_inserts.py +++ b/tests/stress_tests/test_multi_inserts.py @@ -77,4 +77,4 @@ def test_in_flight_is_one(self): break i = i + 1 - self.assertFalse(leaking_connections, 'Detected leaking connection after %s iterations' % i) + assert not leaking_connections, 'Detected leaking connection after %s iterations' % i diff --git a/tests/unit/advanced/test_execution_profile.py b/tests/unit/advanced/test_execution_profile.py index 478322f95b..bc51388c00 100644 --- a/tests/unit/advanced/test_execution_profile.py +++ b/tests/unit/advanced/test_execution_profile.py @@ -23,9 +23,9 @@ class GraphExecutionProfileTest(unittest.TestCase): def test_graph_source_can_be_set_with_graph_execution_profile(self): options = GraphOptions(graph_source='a') ep = GraphExecutionProfile(graph_options=options) - self.assertEqual(ep.graph_options.graph_source, b'a') + assert ep.graph_options.graph_source == b'a' def test_graph_source_is_preserve_with_graph_analytics_execution_profile(self): options = GraphOptions(graph_source='doesnt_matter') ep = GraphAnalyticsExecutionProfile(graph_options=options) - self.assertEqual(ep.graph_options.graph_source, b'a') # graph source is set automatically + assert ep.graph_options.graph_source == b'a' # graph source is set automatically diff --git a/tests/unit/advanced/test_geometry.py b/tests/unit/advanced/test_geometry.py index d85f1bc293..1927b51da7 100644 --- a/tests/unit/advanced/test_geometry.py +++ b/tests/unit/advanced/test_geometry.py @@ -20,6 +20,7 @@ from cassandra.protocol import ProtocolVersion from cassandra.cqltypes import PointType, LineStringType, PolygonType, WKBGeometryType from cassandra.util import Point, LineString, Polygon, _LinearRing, Distance, _HAS_GEOMET +import pytest wkb_be = 0 wkb_le = 1 @@ -35,12 +36,12 @@ def test_marshal_platform(self): for proto_ver in protocol_versions: for geo in self.samples: cql_type = 
lookup_casstype(geo.__class__.__name__ + 'Type') - self.assertEqual(cql_type.from_binary(cql_type.to_binary(geo, proto_ver), proto_ver), geo) + assert cql_type.from_binary(cql_type.to_binary(geo, proto_ver), proto_ver) == geo def _verify_both_endian(self, typ, body_fmt, params, expected): for proto_ver in protocol_versions: - self.assertEqual(typ.from_binary(struct.pack(">BI" + body_fmt, wkb_be, *params), proto_ver), expected) - self.assertEqual(typ.from_binary(struct.pack("BI" + body_fmt, wkb_be, *params), proto_ver) == expected + assert typ.from_binary(struct.pack(" base map base = GraphOptions(**self.api_params) - self.assertEqual(GraphOptions().get_options_map(base), base._graph_options) + assert GraphOptions().get_options_map(base) == base._graph_options # something set overrides kwargs = self.api_params.copy() # this test concept got strange after we added default values for a couple GraphOption attrs @@ -276,9 +282,9 @@ def test_get_options(self): other = GraphOptions(**kwargs) options = base.get_options_map(other) updated = self.opt_mapping['graph_name'] - self.assertEqual(options[updated], b'unit_test') + assert options[updated] == b'unit_test' for name in (n for n in self.opt_mapping.values() if n != updated): - self.assertEqual(options[name], base._graph_options[name]) + assert options[name] == base._graph_options[name] # base unchanged self._verify_api_params(base, self.api_params) @@ -286,24 +292,24 @@ def test_get_options(self): def test_set_attr(self): expected = 'test@@@@' opts = GraphOptions(graph_name=expected) - self.assertEqual(opts.graph_name, expected.encode()) + assert opts.graph_name == expected.encode() expected = 'somethingelse####' opts.graph_name = expected - self.assertEqual(opts.graph_name, expected.encode()) + assert opts.graph_name == expected.encode() # will update options with set value another = GraphOptions() - self.assertIsNone(another.graph_name) + assert another.graph_name is None another.update(opts) - 
self.assertEqual(another.graph_name, expected.encode()) + assert another.graph_name == expected.encode() opts.graph_name = None - self.assertIsNone(opts.graph_name) + assert opts.graph_name is None # will not update another with its set-->unset value another.update(opts) - self.assertEqual(another.graph_name, expected.encode()) # remains unset + assert another.graph_name == expected.encode() # remains unset opt_map = another.get_options_map(opts) - self.assertEqual(opt_map, another._graph_options) + assert opt_map == another._graph_options def test_del_attr(self): opts = GraphOptions(**self.api_params) @@ -313,14 +319,14 @@ def test_del_attr(self): self._verify_api_params(opts, test_params) def _verify_api_params(self, opts, api_params): - self.assertEqual(len(opts._graph_options), len(api_params)) + assert len(opts._graph_options) == len(api_params) for name, value in api_params.items(): try: value = value.encode() except: pass # already bytes - self.assertEqual(getattr(opts, name), value) - self.assertEqual(opts._graph_options[self.opt_mapping[name]], value) + assert getattr(opts, name) == value + assert opts._graph_options[self.opt_mapping[name]] == value def test_consistency_levels(self): read_cl = ConsistencyLevel.ONE @@ -328,49 +334,49 @@ def test_consistency_levels(self): # set directly opts = GraphOptions(graph_read_consistency_level=read_cl, graph_write_consistency_level=write_cl) - self.assertEqual(opts.graph_read_consistency_level, read_cl) - self.assertEqual(opts.graph_write_consistency_level, write_cl) + assert opts.graph_read_consistency_level == read_cl + assert opts.graph_write_consistency_level == write_cl # mapping from base opt_map = opts.get_options_map() - self.assertEqual(opt_map['graph-read-consistency'], ConsistencyLevel.value_to_name[read_cl].encode()) - self.assertEqual(opt_map['graph-write-consistency'], ConsistencyLevel.value_to_name[write_cl].encode()) + assert opt_map['graph-read-consistency'] == 
ConsistencyLevel.value_to_name[read_cl].encode() + assert opt_map['graph-write-consistency'] == ConsistencyLevel.value_to_name[write_cl].encode() # empty by default new_opts = GraphOptions() opt_map = new_opts.get_options_map() - self.assertNotIn('graph-read-consistency', opt_map) - self.assertNotIn('graph-write-consistency', opt_map) + assert 'graph-read-consistency' not in opt_map + assert 'graph-write-consistency' not in opt_map # set from other opt_map = new_opts.get_options_map(opts) - self.assertEqual(opt_map['graph-read-consistency'], ConsistencyLevel.value_to_name[read_cl].encode()) - self.assertEqual(opt_map['graph-write-consistency'], ConsistencyLevel.value_to_name[write_cl].encode()) + assert opt_map['graph-read-consistency'] == ConsistencyLevel.value_to_name[read_cl].encode() + assert opt_map['graph-write-consistency'] == ConsistencyLevel.value_to_name[write_cl].encode() def test_graph_source_convenience_attributes(self): opts = GraphOptions() - self.assertEqual(opts.graph_source, b'g') - self.assertFalse(opts.is_analytics_source) - self.assertTrue(opts.is_graph_source) - self.assertFalse(opts.is_default_source) + assert opts.graph_source == b'g' + assert not opts.is_analytics_source + assert opts.is_graph_source + assert not opts.is_default_source opts.set_source_default() - self.assertIsNotNone(opts.graph_source) - self.assertFalse(opts.is_analytics_source) - self.assertFalse(opts.is_graph_source) - self.assertTrue(opts.is_default_source) + assert opts.graph_source is not None + assert not opts.is_analytics_source + assert not opts.is_graph_source + assert opts.is_default_source opts.set_source_analytics() - self.assertIsNotNone(opts.graph_source) - self.assertTrue(opts.is_analytics_source) - self.assertFalse(opts.is_graph_source) - self.assertFalse(opts.is_default_source) + assert opts.graph_source is not None + assert opts.is_analytics_source + assert not opts.is_graph_source + assert not opts.is_default_source opts.set_source_graph() - 
self.assertIsNotNone(opts.graph_source) - self.assertFalse(opts.is_analytics_source) - self.assertTrue(opts.is_graph_source) - self.assertFalse(opts.is_default_source) + assert opts.graph_source is not None + assert not opts.is_analytics_source + assert opts.is_graph_source + assert not opts.is_default_source class GraphStatementTests(unittest.TestCase): @@ -384,11 +390,12 @@ def test_init(self): 'custom_payload': object()} statement = SimpleGraphStatement(**kwargs) for k, v in kwargs.items(): - self.assertIs(getattr(statement, k), v) + assert getattr(statement, k) is v # but not a bogus parameter kwargs['bogus'] = object() - self.assertRaises(TypeError, SimpleGraphStatement, **kwargs) + with pytest.raises(TypeError): + SimpleGraphStatement(**kwargs) class GraphRowFactoryTests(unittest.TestCase): @@ -396,12 +403,12 @@ class GraphRowFactoryTests(unittest.TestCase): def test_object_row_factory(self): col_names = [] # unused rows = [object() for _ in range(10)] - self.assertEqual(single_object_row_factory(col_names, ((o,) for o in rows)), rows) + assert single_object_row_factory(col_names, ((o,) for o in rows)) == rows def test_graph_result_row_factory(self): col_names = [] # unused rows = [json.dumps({'result': i}) for i in range(10)] results = graph_result_row_factory(col_names, ((o,) for o in rows)) for i, res in enumerate(results): - self.assertIsInstance(res, Result) - self.assertEqual(res.value, i) + assert isinstance(res, Result) + assert res.value == i diff --git a/tests/unit/advanced/test_insights.py b/tests/unit/advanced/test_insights.py index 4047fe12b8..6646e6746f 100644 --- a/tests/unit/advanced/test_insights.py +++ b/tests/unit/advanced/test_insights.py @@ -66,18 +66,15 @@ class NoConfAsDict(object): # no default # ... 
as a policy - self.assertEqual(insights_registry.serialize(obj, policy=True), - {'type': 'NoConfAsDict', - 'namespace': ns, - 'options': {}}) + assert insights_registry.serialize(obj, policy=True) == {'type': 'NoConfAsDict', + 'namespace': ns, + 'options': {}} # ... not as a policy (default) - self.assertEqual(insights_registry.serialize(obj), - {'type': 'NoConfAsDict', - 'namespace': ns, - }) + assert insights_registry.serialize(obj) == {'type': 'NoConfAsDict', + 'namespace': ns, + } # with default - self.assertIs(insights_registry.serialize(obj, default=sentinel.attr_err_default), - sentinel.attr_err_default) + assert insights_registry.serialize(obj, default=sentinel.attr_err_default) is sentinel.attr_err_default def test_successful_return(self): @@ -91,14 +88,11 @@ class SubclassSentinel(SuperclassSentinel): def superclass_sentinel_serializer(obj): return sentinel.serialized_superclass - self.assertIs(insights_registry.serialize(SuperclassSentinel()), - sentinel.serialized_superclass) - self.assertIs(insights_registry.serialize(SubclassSentinel()), - sentinel.serialized_superclass) + assert insights_registry.serialize(SuperclassSentinel()) is sentinel.serialized_superclass + assert insights_registry.serialize(SubclassSentinel()) is sentinel.serialized_superclass # with default -- same behavior - self.assertIs(insights_registry.serialize(SubclassSentinel(), default=object()), - sentinel.serialized_superclass) + assert insights_registry.serialize(SubclassSentinel(), default=object()) is sentinel.serialized_superclass class TestConfigAsDict(unittest.TestCase): @@ -116,190 +110,145 @@ def test_graph_options(self): log.debug(go._graph_options) - self.assertEqual( - insights_registry.serialize(go), - {'source': 'source_for_test', - 'language': 'lang_for_test', - 'graphProtocol': 'protocol_for_test', - # no graph_invalid_option - } - ) + assert insights_registry.serialize(go) == {'source': 'source_for_test', + 'language': 'lang_for_test', + 'graphProtocol': 
'protocol_for_test', + # no graph_invalid_option + } # cluster.py def test_execution_profile(self): self.maxDiff = None - self.assertEqual( - insights_registry.serialize(ExecutionProfile()), - {'consistency': 'LOCAL_ONE', - 'continuousPagingOptions': None, - 'loadBalancing': {'namespace': 'cassandra.policies', - 'options': {'child_policy': {'namespace': 'cassandra.policies', - 'options': {'local_dc': '', - 'used_hosts_per_remote_dc': 0}, - 'type': 'DCAwareRoundRobinPolicy'}, - 'shuffle_replicas': False}, - 'type': 'TokenAwarePolicy'}, - 'readTimeout': 10.0, - 'retry': {'namespace': 'cassandra.policies', 'options': {}, 'type': 'RetryPolicy'}, - 'serialConsistency': None, - 'speculativeExecution': {'namespace': 'cassandra.policies', - 'options': {}, 'type': 'NoSpeculativeExecutionPolicy'}, - 'graphOptions': None - } - ) + assert insights_registry.serialize(ExecutionProfile()) == {'consistency': 'LOCAL_ONE', + 'continuousPagingOptions': None, + 'loadBalancing': {'namespace': 'cassandra.policies', + 'options': {'child_policy': {'namespace': 'cassandra.policies', + 'options': {'local_dc': '', + 'used_hosts_per_remote_dc': 0}, + 'type': 'DCAwareRoundRobinPolicy'}, + 'shuffle_replicas': False}, + 'type': 'TokenAwarePolicy'}, + 'readTimeout': 10.0, + 'retry': {'namespace': 'cassandra.policies', 'options': {}, 'type': 'RetryPolicy'}, + 'serialConsistency': None, + 'speculativeExecution': {'namespace': 'cassandra.policies', + 'options': {}, 'type': 'NoSpeculativeExecutionPolicy'}, + 'graphOptions': None + } def test_graph_execution_profile(self): self.maxDiff = None - self.assertEqual( - insights_registry.serialize(GraphExecutionProfile()), - {'consistency': 'LOCAL_ONE', - 'continuousPagingOptions': None, - 'loadBalancing': {'namespace': 'cassandra.policies', - 'options': {'child_policy': {'namespace': 'cassandra.policies', - 'options': {'local_dc': '', - 'used_hosts_per_remote_dc': 0}, - 'type': 'DCAwareRoundRobinPolicy'}, - 'shuffle_replicas': False}, - 'type': 
'TokenAwarePolicy'}, - 'readTimeout': 30.0, - 'retry': {'namespace': 'cassandra.policies', 'options': {}, 'type': 'NeverRetryPolicy'}, - 'serialConsistency': None, - 'speculativeExecution': {'namespace': 'cassandra.policies', - 'options': {}, 'type': 'NoSpeculativeExecutionPolicy'}, - 'graphOptions': {'graphProtocol': None, - 'language': 'gremlin-groovy', - 'source': 'g'}, - } - ) + assert insights_registry.serialize(GraphExecutionProfile()) == {'consistency': 'LOCAL_ONE', + 'continuousPagingOptions': None, + 'loadBalancing': {'namespace': 'cassandra.policies', + 'options': {'child_policy': {'namespace': 'cassandra.policies', + 'options': {'local_dc': '', + 'used_hosts_per_remote_dc': 0}, + 'type': 'DCAwareRoundRobinPolicy'}, + 'shuffle_replicas': False}, + 'type': 'TokenAwarePolicy'}, + 'readTimeout': 30.0, + 'retry': {'namespace': 'cassandra.policies', 'options': {}, 'type': 'NeverRetryPolicy'}, + 'serialConsistency': None, + 'speculativeExecution': {'namespace': 'cassandra.policies', + 'options': {}, 'type': 'NoSpeculativeExecutionPolicy'}, + 'graphOptions': {'graphProtocol': None, + 'language': 'gremlin-groovy', + 'source': 'g'}, + } def test_graph_analytics_execution_profile(self): self.maxDiff = None - self.assertEqual( - insights_registry.serialize(GraphAnalyticsExecutionProfile()), - {'consistency': 'LOCAL_ONE', - 'continuousPagingOptions': None, - 'loadBalancing': {'namespace': 'cassandra.policies', - 'options': {'child_policy': {'namespace': 'cassandra.policies', - 'options': {'child_policy': {'namespace': 'cassandra.policies', - 'options': {'local_dc': '', - 'used_hosts_per_remote_dc': 0}, - 'type': 'DCAwareRoundRobinPolicy'}, - 'shuffle_replicas': False}, - 'type': 'TokenAwarePolicy'}}, - 'type': 'DefaultLoadBalancingPolicy'}, - 'readTimeout': 604800.0, - 'retry': {'namespace': 'cassandra.policies', 'options': {}, 'type': 'NeverRetryPolicy'}, - 'serialConsistency': None, - 'speculativeExecution': {'namespace': 'cassandra.policies', - 'options': {}, 
'type': 'NoSpeculativeExecutionPolicy'}, - 'graphOptions': {'graphProtocol': None, - 'language': 'gremlin-groovy', - 'source': 'a'}, - } - ) + assert insights_registry.serialize(GraphAnalyticsExecutionProfile()) == {'consistency': 'LOCAL_ONE', + 'continuousPagingOptions': None, + 'loadBalancing': {'namespace': 'cassandra.policies', + 'options': {'child_policy': {'namespace': 'cassandra.policies', + 'options': {'child_policy': {'namespace': 'cassandra.policies', + 'options': {'local_dc': '', + 'used_hosts_per_remote_dc': 0}, + 'type': 'DCAwareRoundRobinPolicy'}, + 'shuffle_replicas': False}, + 'type': 'TokenAwarePolicy'}}, + 'type': 'DefaultLoadBalancingPolicy'}, + 'readTimeout': 604800.0, + 'retry': {'namespace': 'cassandra.policies', 'options': {}, 'type': 'NeverRetryPolicy'}, + 'serialConsistency': None, + 'speculativeExecution': {'namespace': 'cassandra.policies', + 'options': {}, 'type': 'NoSpeculativeExecutionPolicy'}, + 'graphOptions': {'graphProtocol': None, + 'language': 'gremlin-groovy', + 'source': 'a'}, + } # policies.py def test_DC_aware_round_robin_policy(self): - self.assertEqual( - insights_registry.serialize(DCAwareRoundRobinPolicy()), - {'namespace': 'cassandra.policies', - 'options': {'local_dc': '', 'used_hosts_per_remote_dc': 0}, - 'type': 'DCAwareRoundRobinPolicy'} - ) - self.assertEqual( - insights_registry.serialize(DCAwareRoundRobinPolicy(local_dc='fake_local_dc', - used_hosts_per_remote_dc=15)), - {'namespace': 'cassandra.policies', - 'options': {'local_dc': 'fake_local_dc', 'used_hosts_per_remote_dc': 15}, - 'type': 'DCAwareRoundRobinPolicy'} - ) + assert insights_registry.serialize(DCAwareRoundRobinPolicy()) == {'namespace': 'cassandra.policies', + 'options': {'local_dc': '', 'used_hosts_per_remote_dc': 0}, + 'type': 'DCAwareRoundRobinPolicy'} + assert insights_registry.serialize(DCAwareRoundRobinPolicy(local_dc='fake_local_dc', + used_hosts_per_remote_dc=15)) == {'namespace': 'cassandra.policies', + 'options': {'local_dc': 
'fake_local_dc', 'used_hosts_per_remote_dc': 15}, + 'type': 'DCAwareRoundRobinPolicy'} def test_token_aware_policy(self): - self.assertEqual( - insights_registry.serialize(TokenAwarePolicy(child_policy=LoadBalancingPolicy())), - {'namespace': 'cassandra.policies', - 'options': {'child_policy': {'namespace': 'cassandra.policies', - 'options': {}, - 'type': 'LoadBalancingPolicy'}, - 'shuffle_replicas': False}, - 'type': 'TokenAwarePolicy'} - ) + assert insights_registry.serialize(TokenAwarePolicy(child_policy=LoadBalancingPolicy())) == {'namespace': 'cassandra.policies', + 'options': {'child_policy': {'namespace': 'cassandra.policies', + 'options': {}, + 'type': 'LoadBalancingPolicy'}, + 'shuffle_replicas': False}, + 'type': 'TokenAwarePolicy'} def test_whitelist_round_robin_policy(self): - self.assertEqual( - insights_registry.serialize(WhiteListRoundRobinPolicy(['127.0.0.3'])), - {'namespace': 'cassandra.policies', - 'options': {'allowed_hosts': ('127.0.0.3',)}, - 'type': 'WhiteListRoundRobinPolicy'} - ) + assert insights_registry.serialize(WhiteListRoundRobinPolicy(['127.0.0.3'])) == {'namespace': 'cassandra.policies', + 'options': {'allowed_hosts': ('127.0.0.3',)}, + 'type': 'WhiteListRoundRobinPolicy'} def test_host_filter_policy(self): def my_predicate(s): return False - self.assertEqual( - insights_registry.serialize(HostFilterPolicy(LoadBalancingPolicy(), my_predicate)), - {'namespace': 'cassandra.policies', - 'options': {'child_policy': {'namespace': 'cassandra.policies', - 'options': {}, - 'type': 'LoadBalancingPolicy'}, - 'predicate': 'my_predicate'}, - 'type': 'HostFilterPolicy'} - ) + assert insights_registry.serialize(HostFilterPolicy(LoadBalancingPolicy(), my_predicate)) == {'namespace': 'cassandra.policies', + 'options': {'child_policy': {'namespace': 'cassandra.policies', + 'options': {}, + 'type': 'LoadBalancingPolicy'}, + 'predicate': 'my_predicate'}, + 'type': 'HostFilterPolicy'} def test_constant_reconnection_policy(self): - self.assertEqual( - 
insights_registry.serialize(ConstantReconnectionPolicy(3, 200)), - {'type': 'ConstantReconnectionPolicy', - 'namespace': 'cassandra.policies', - 'options': {'delay': 3, 'max_attempts': 200} - } - ) + assert insights_registry.serialize(ConstantReconnectionPolicy(3, 200)) == {'type': 'ConstantReconnectionPolicy', + 'namespace': 'cassandra.policies', + 'options': {'delay': 3, 'max_attempts': 200} + } def test_exponential_reconnection_policy(self): - self.assertEqual( - insights_registry.serialize(ExponentialReconnectionPolicy(4, 100, 10)), - {'type': 'ExponentialReconnectionPolicy', - 'namespace': 'cassandra.policies', - 'options': {'base_delay': 4, 'max_delay': 100, 'max_attempts': 10} - } - ) + assert insights_registry.serialize(ExponentialReconnectionPolicy(4, 100, 10)) == {'type': 'ExponentialReconnectionPolicy', + 'namespace': 'cassandra.policies', + 'options': {'base_delay': 4, 'max_delay': 100, 'max_attempts': 10} + } def test_retry_policy(self): - self.assertEqual( - insights_registry.serialize(RetryPolicy()), - {'type': 'RetryPolicy', - 'namespace': 'cassandra.policies', - 'options': {} - } - ) + assert insights_registry.serialize(RetryPolicy()) == {'type': 'RetryPolicy', + 'namespace': 'cassandra.policies', + 'options': {} + } def test_spec_exec_policy(self): - self.assertEqual( - insights_registry.serialize(SpeculativeExecutionPolicy()), - {'type': 'SpeculativeExecutionPolicy', - 'namespace': 'cassandra.policies', - 'options': {} - } - ) + assert insights_registry.serialize(SpeculativeExecutionPolicy()) == {'type': 'SpeculativeExecutionPolicy', + 'namespace': 'cassandra.policies', + 'options': {} + } def test_constant_spec_exec_policy(self): - self.assertEqual( - insights_registry.serialize(ConstantSpeculativeExecutionPolicy(100, 101)), - {'type': 'ConstantSpeculativeExecutionPolicy', - 'namespace': 'cassandra.policies', - 'options': {'delay': 100, - 'max_attempts': 101} - } - ) + assert insights_registry.serialize(ConstantSpeculativeExecutionPolicy(100, 
101)) == {'type': 'ConstantSpeculativeExecutionPolicy', + 'namespace': 'cassandra.policies', + 'options': {'delay': 100, + 'max_attempts': 101} + } def test_wrapper_policy(self): - self.assertEqual( - insights_registry.serialize(WrapperPolicy(LoadBalancingPolicy())), - {'namespace': 'cassandra.policies', - 'options': {'child_policy': {'namespace': 'cassandra.policies', - 'options': {}, - 'type': 'LoadBalancingPolicy'} - }, - 'type': 'WrapperPolicy'} - ) + assert insights_registry.serialize(WrapperPolicy(LoadBalancingPolicy())) == {'namespace': 'cassandra.policies', + 'options': {'child_policy': {'namespace': 'cassandra.policies', + 'options': {}, + 'type': 'LoadBalancingPolicy'} + }, + 'type': 'WrapperPolicy'} diff --git a/tests/unit/advanced/test_metadata.py b/tests/unit/advanced/test_metadata.py index 20f80b4da4..5ccfa5e477 100644 --- a/tests/unit/advanced/test_metadata.py +++ b/tests/unit/advanced/test_metadata.py @@ -48,95 +48,71 @@ def _create_table_metadata(self, with_vertex=False, with_edge=False): def test_keyspace_no_graph_engine(self): km = self._create_keyspace_metadata(None) - self.assertEqual(km.graph_engine, None) - self.assertNotIn( - "graph_engine", - km.as_cql_query() - ) + assert km.graph_engine == None + assert "graph_engine" not in km.as_cql_query() def test_keyspace_with_graph_engine(self): graph_engine = 'Core' km = self._create_keyspace_metadata(graph_engine) - self.assertEqual(km.graph_engine, graph_engine) + assert km.graph_engine == graph_engine cql = km.as_cql_query() - self.assertIn( - "graph_engine", - cql - ) - self.assertIn( - "Core", - cql - ) + assert "graph_engine" in cql + assert "Core" in cql def test_table_no_vertex_or_edge(self): tm = self._create_table_metadata() - self.assertIsNone(tm.vertex) - self.assertIsNone(tm.edge) + assert tm.vertex is None + assert tm.edge is None cql = tm.as_cql_query() - self.assertNotIn("VERTEX LABEL", cql) - self.assertNotIn("EDGE LABEL", cql) + assert "VERTEX LABEL" not in cql + assert "EDGE 
LABEL" not in cql def test_table_with_vertex(self): tm = self._create_table_metadata(with_vertex=True) - self.assertIsInstance(tm.vertex, VertexMetadata) - self.assertIsNone(tm.edge) + assert isinstance(tm.vertex, VertexMetadata) + assert tm.edge is None cql = tm.as_cql_query() - self.assertIn("VERTEX LABEL", cql) - self.assertNotIn("EDGE LABEL", cql) + assert "VERTEX LABEL" in cql + assert "EDGE LABEL" not in cql def test_table_with_edge(self): tm = self._create_table_metadata(with_edge=True) - self.assertIsNone(tm.vertex) - self.assertIsInstance(tm.edge, EdgeMetadata) + assert tm.vertex is None + assert isinstance(tm.edge, EdgeMetadata) cql = tm.as_cql_query() - self.assertNotIn("VERTEX LABEL", cql) - self.assertIn("EDGE LABEL", cql) - self.assertIn("FROM from_label", cql) - self.assertIn("TO to_label", cql) + assert "VERTEX LABEL" not in cql + assert "EDGE LABEL" in cql + assert "FROM from_label" in cql + assert "TO to_label" in cql def test_vertex_with_label(self): tm = self. _create_table_metadata(with_vertex=True) - self.assertTrue(tm.as_cql_query().endswith('VERTEX LABEL label')) + assert tm.as_cql_query().endswith('VERTEX LABEL label') def test_edge_single_partition_key_and_clustering_key(self): tm = self._create_table_metadata(with_edge=True) - self.assertIn( - 'FROM from_label(pk1, c1)', - tm.as_cql_query() - ) + assert 'FROM from_label(pk1, c1)' in tm.as_cql_query() def test_edge_multiple_partition_keys(self): edge = self._create_edge_metadata(partition_keys=['pk1', 'pk2']) tm = self. _create_table_metadata(with_edge=edge) - self.assertIn( - 'FROM from_label((pk1, pk2), ', - tm.as_cql_query() - ) + assert 'FROM from_label((pk1, pk2), ' in tm.as_cql_query() def test_edge_no_clustering_keys(self): edge = self._create_edge_metadata(clustering_keys=[]) tm = self. 
_create_table_metadata(with_edge=edge) - self.assertIn( - 'FROM from_label(pk1) ', - tm.as_cql_query() - ) + assert 'FROM from_label(pk1) ' in tm.as_cql_query() def test_edge_multiple_clustering_keys(self): edge = self._create_edge_metadata(clustering_keys=['c1', 'c2']) tm = self. _create_table_metadata(with_edge=edge) - self.assertIn( - 'FROM from_label(pk1, c1, c2) ', - tm.as_cql_query() - ) + assert 'FROM from_label(pk1, c1, c2) ' in tm.as_cql_query() def test_edge_multiple_partition_and_clustering_keys(self): edge = self._create_edge_metadata(partition_keys=['pk1', 'pk2'], clustering_keys=['c1', 'c2']) tm = self. _create_table_metadata(with_edge=edge) - self.assertIn( - 'FROM from_label((pk1, pk2), c1, c2) ', - tm.as_cql_query() - ) + assert 'FROM from_label((pk1, pk2), c1, c2) ' in tm.as_cql_query() class SchemaParsersTests(unittest.TestCase): @@ -159,16 +135,14 @@ def wait_for_responses(self, *msgs, **kwargs): p._query_all() for q in conn.queries: - if "USING TIMEOUT" in q.query: - self.fail(f"<{schemaClass.__name__}> query `{q.query}` contains `USING TIMEOUT`, while should not") + assert "USING TIMEOUT" not in q.query, f"<{schemaClass.__name__}> query `{q.query}` contains `USING TIMEOUT`, while should not" conn = FakeConnection() p = schemaClass(conn, 2.0, 1000, datetime.timedelta(seconds=2)) p._query_all() for q in conn.queries: - if "USING TIMEOUT 2000ms" not in q.query: - self.fail(f"{schemaClass.__name__} query `{q.query}` does not contain `USING TIMEOUT 2000ms`") + assert "USING TIMEOUT 2000ms" in q.query, f"{schemaClass.__name__} query `{q.query}` does not contain `USING TIMEOUT 2000ms`" def get_all_schema_parser_classes(cl): diff --git a/tests/unit/advanced/test_policies.py b/tests/unit/advanced/test_policies.py index 553e7dba87..8e421a859d 100644 --- a/tests/unit/advanced/test_policies.py +++ b/tests/unit/advanced/test_policies.py @@ -37,7 +37,7 @@ def test_no_target(self): policy.populate(Mock(metadata=ClusterMetaMock()), hosts) for _ in 
range(node_count): query_plan = list(policy.make_query_plan(None, Mock(target_host=None))) - self.assertEqual(sorted(query_plan), hosts) + assert sorted(query_plan) == hosts def test_status_updates(self): node_count = 4 @@ -49,7 +49,7 @@ def test_status_updates(self): policy.on_up(4) policy.on_add(5) query_plan = list(policy.make_query_plan()) - self.assertEqual(sorted(query_plan), [2, 3, 4, 5]) + assert sorted(query_plan) == [2, 3, 4, 5] def test_no_live_nodes(self): hosts = [0, 1, 2, 3] @@ -60,7 +60,7 @@ def test_no_live_nodes(self): policy.on_down(i) query_plan = list(policy.make_query_plan()) - self.assertEqual(query_plan, []) + assert query_plan == [] def test_target_no_host(self): node_count = 4 @@ -68,7 +68,7 @@ def test_target_no_host(self): policy = DSELoadBalancingPolicy(RoundRobinPolicy()) policy.populate(Mock(metadata=ClusterMetaMock()), hosts) query_plan = list(policy.make_query_plan(None, Mock(target_host='127.0.0.1'))) - self.assertEqual(sorted(query_plan), hosts) + assert sorted(query_plan) == hosts def test_target_host_down(self): node_count = 4 @@ -78,12 +78,12 @@ def test_target_host_down(self): policy = DSELoadBalancingPolicy(RoundRobinPolicy()) policy.populate(Mock(metadata=ClusterMetaMock({'127.0.0.1': target_host})), hosts) query_plan = list(policy.make_query_plan(None, Mock(target_host='127.0.0.1'))) - self.assertEqual(sorted(query_plan), hosts) + assert sorted(query_plan) == hosts target_host.is_up = False policy.on_down(target_host) query_plan = list(policy.make_query_plan(None, Mock(target_host='127.0.0.1'))) - self.assertNotIn(target_host, query_plan) + assert target_host not in query_plan def test_target_host_nominal(self): node_count = 4 @@ -95,5 +95,5 @@ def test_target_host_nominal(self): policy.populate(Mock(metadata=ClusterMetaMock({'127.0.0.1': target_host})), hosts) for _ in range(10): query_plan = list(policy.make_query_plan(None, Mock(target_host='127.0.0.1'))) - self.assertEqual(sorted(query_plan), hosts) - 
self.assertEqual(query_plan[0], target_host) + assert sorted(query_plan) == hosts + assert query_plan[0] == target_host diff --git a/tests/unit/column_encryption/test_policies.py b/tests/unit/column_encryption/test_policies.py index 27e7c62ce7..1bd83ecf89 100644 --- a/tests/unit/column_encryption/test_policies.py +++ b/tests/unit/column_encryption/test_policies.py @@ -18,6 +18,7 @@ from cassandra.policies import ColDesc from cassandra.column_encryption.policies import AES256ColumnEncryptionPolicy, \ AES256_BLOCK_SIZE_BYTES, AES256_KEY_SIZE_BYTES +import pytest @unittest.skip("Skip until https://github.com/scylladb/python-driver/issues/365 is sorted out") class AES256ColumnEncryptionPolicyTest(unittest.TestCase): @@ -33,7 +34,7 @@ def _test_round_trip(self, bytes): policy = AES256ColumnEncryptionPolicy() policy.add_column(coldesc, self._random_key(), "blob") encrypted_bytes = policy.encrypt(coldesc, bytes) - self.assertEqual(bytes, policy.decrypt(coldesc, encrypted_bytes)) + assert bytes == policy.decrypt(coldesc, encrypted_bytes) def test_no_padding_necessary(self): self._test_round_trip(self._random_block()) @@ -50,10 +51,10 @@ def test_add_column_invalid_key_size_raises(self): coldesc = ColDesc('ks1','table1','col1') policy = AES256ColumnEncryptionPolicy() for key_size in range(1,AES256_KEY_SIZE_BYTES - 1): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): policy.add_column(coldesc, os.urandom(key_size), "blob") for key_size in range(AES256_KEY_SIZE_BYTES + 1,(2 * AES256_KEY_SIZE_BYTES) - 1): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): policy.add_column(coldesc, os.urandom(key_size), "blob") def test_add_column_invalid_iv_size_raises(self): @@ -64,54 +65,54 @@ def test_iv_size(iv_size): coldesc = ColDesc('ks1','table1','col1') for iv_size in range(1,AES256_BLOCK_SIZE_BYTES - 1): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): test_iv_size(iv_size) for iv_size in 
range(AES256_BLOCK_SIZE_BYTES + 1,(2 * AES256_BLOCK_SIZE_BYTES) - 1): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): test_iv_size(iv_size) # Finally, confirm that the expected IV size has no issue test_iv_size(AES256_BLOCK_SIZE_BYTES) def test_add_column_null_coldesc_raises(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): policy = AES256ColumnEncryptionPolicy() policy.add_column(None, self._random_block(), "blob") def test_add_column_null_key_raises(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): policy = AES256ColumnEncryptionPolicy() coldesc = ColDesc('ks1','table1','col1') policy.add_column(coldesc, None, "blob") def test_add_column_null_type_raises(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): policy = AES256ColumnEncryptionPolicy() coldesc = ColDesc('ks1','table1','col1') policy.add_column(coldesc, self._random_block(), None) def test_add_column_unknown_type_raises(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): policy = AES256ColumnEncryptionPolicy() coldesc = ColDesc('ks1','table1','col1') policy.add_column(coldesc, self._random_block(), "foobar") def test_encode_and_encrypt_null_coldesc_raises(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): policy = AES256ColumnEncryptionPolicy() coldesc = ColDesc('ks1','table1','col1') policy.add_column(coldesc, self._random_key(), "blob") policy.encode_and_encrypt(None, self._random_block()) def test_encode_and_encrypt_null_obj_raises(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): policy = AES256ColumnEncryptionPolicy() coldesc = ColDesc('ks1','table1','col1') policy.add_column(coldesc, self._random_key(), "blob") policy.encode_and_encrypt(coldesc, None) def test_encode_and_encrypt_unknown_coldesc_raises(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): policy = 
AES256ColumnEncryptionPolicy() coldesc = ColDesc('ks1','table1','col1') policy.add_column(coldesc, self._random_key(), "blob") @@ -121,14 +122,14 @@ def test_contains_column(self): coldesc = ColDesc('ks1','table1','col1') policy = AES256ColumnEncryptionPolicy() policy.add_column(coldesc, self._random_key(), "blob") - self.assertTrue(policy.contains_column(coldesc)) - self.assertFalse(policy.contains_column(ColDesc('ks2','table1','col1'))) - self.assertFalse(policy.contains_column(ColDesc('ks1','table2','col1'))) - self.assertFalse(policy.contains_column(ColDesc('ks1','table1','col2'))) - self.assertFalse(policy.contains_column(ColDesc('ks2','table2','col2'))) + assert policy.contains_column(coldesc) + assert not policy.contains_column(ColDesc('ks2','table1','col1')) + assert not policy.contains_column(ColDesc('ks1','table2','col1')) + assert not policy.contains_column(ColDesc('ks1','table1','col2')) + assert not policy.contains_column(ColDesc('ks2','table2','col2')) def test_encrypt_unknown_column(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): policy = AES256ColumnEncryptionPolicy() coldesc = ColDesc('ks1','table1','col1') policy.add_column(coldesc, self._random_key(), "blob") @@ -139,7 +140,7 @@ def test_decrypt_unknown_column(self): coldesc = ColDesc('ks1','table1','col1') policy.add_column(coldesc, self._random_key(), "blob") encrypted_bytes = policy.encrypt(coldesc, self._random_block()) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): policy.decrypt(ColDesc('ks2','table2','col2'), encrypted_bytes) def test_cache_info(self): @@ -157,14 +158,14 @@ def test_cache_info(self): for _ in range(10): policy.encrypt(coldesc1, self._random_block()) cache_info = policy.cache_info() - self.assertEqual(cache_info.hits, 9) - self.assertEqual(cache_info.misses, 1) - self.assertEqual(cache_info.maxsize, 128) + assert cache_info.hits == 9 + assert cache_info.misses == 1 + assert cache_info.maxsize == 128 # Important note: 
we're measuring the size of the cache of ciphers, NOT stored # keys. We won't have a cipher here until we actually encrypt something - self.assertEqual(cache_info.currsize, 1) + assert cache_info.currsize == 1 policy.encrypt(coldesc2, self._random_block()) - self.assertEqual(policy.cache_info().currsize, 2) + assert policy.cache_info().currsize == 2 policy.encrypt(coldesc3, self._random_block()) - self.assertEqual(policy.cache_info().currsize, 3) + assert policy.cache_info().currsize == 3 diff --git a/tests/unit/cqlengine/test_columns.py b/tests/unit/cqlengine/test_columns.py index a7bf74ec23..cba57e88a6 100644 --- a/tests/unit/cqlengine/test_columns.py +++ b/tests/unit/cqlengine/test_columns.py @@ -22,41 +22,41 @@ class ColumnTest(unittest.TestCase): def test_comparisons(self): c0 = Column() c1 = Column() - self.assertEqual(c1.position - c0.position, 1) + assert c1.position - c0.position == 1 # __ne__ - self.assertNotEqual(c0, c1) - self.assertNotEqual(c0, object()) + assert c0 != c1 + assert c0 != object() # __eq__ - self.assertEqual(c0, c0) - self.assertFalse(c0 == object()) + assert c0 == c0 + assert not c0 == object() # __lt__ - self.assertLess(c0, c1) + assert c0 < c1 try: c0 < object() # this raises for Python 3 except TypeError: pass # __le__ - self.assertLessEqual(c0, c1) - self.assertLessEqual(c0, c0) + assert c0 <= c1 + assert c0 <= c0 try: c0 <= object() # this raises for Python 3 except TypeError: pass # __gt__ - self.assertGreater(c1, c0) + assert c1 > c0 try: c1 > object() # this raises for Python 3 except TypeError: pass # __ge__ - self.assertGreaterEqual(c1, c0) - self.assertGreaterEqual(c1, c1) + assert c1 >= c0 + assert c1 >= c1 try: c1 >= object() # this raises for Python 3 except TypeError: @@ -64,5 +64,5 @@ def test_comparisons(self): def test_hash(self): c0 = Column() - self.assertEqual(id(c0), c0.__hash__()) + assert id(c0) == c0.__hash__() diff --git a/tests/unit/cqlengine/test_connection.py b/tests/unit/cqlengine/test_connection.py index 
76266cff23..9bce715f4e 100644 --- a/tests/unit/cqlengine/test_connection.py +++ b/tests/unit/cqlengine/test_connection.py @@ -18,6 +18,7 @@ from cassandra.cluster import _ConfigMode from cassandra.cqlengine import connection from cassandra.query import dict_factory +import pytest class ConnectionTest(unittest.TestCase): @@ -26,10 +27,7 @@ class ConnectionTest(unittest.TestCase): def setUp(self): super(ConnectionTest, self).setUp() - self.assertFalse( - connection._connections, - 'Test precondition not met: connections are registered: {cs}'.format(cs=connection._connections) - ) + assert not connection._connections, 'Test precondition not met: connections are registered: {cs}'.format(cs=connection._connections) def test_set_session_without_existing_connection(self): """ @@ -49,12 +47,12 @@ def test_get_session_fails_without_existing_connection(self): """ Users can't get the default session without having a default connection set. """ - with self.assertRaisesRegex(connection.CQLEngineException, self.no_registered_connection_msg): + with pytest.raises(connection.CQLEngineException, match=self.no_registered_connection_msg): connection.get_session(connection=None) def test_get_cluster_fails_without_existing_connection(self): """ Users can't get the default cluster without having a default connection set. """ - with self.assertRaisesRegex(connection.CQLEngineException, self.no_registered_connection_msg): + with pytest.raises(connection.CQLEngineException, match=self.no_registered_connection_msg): connection.get_cluster(connection=None) diff --git a/tests/unit/cython/bytesio_testhelper.pyx b/tests/unit/cython/bytesio_testhelper.pyx index 7ba91bc4c0..595cd29cc8 100644 --- a/tests/unit/cython/bytesio_testhelper.pyx +++ b/tests/unit/cython/bytesio_testhelper.pyx @@ -13,32 +13,27 @@ # limitations under the License. 
from cassandra.bytesio cimport BytesIOReader +import pytest -def test_read1(assert_equal, assert_raises): +def test_read1(): cdef BytesIOReader reader = BytesIOReader(b'abcdef') - assert_equal(reader.read(2)[:2], b'ab') - assert_equal(reader.read(2)[:2], b'cd') - assert_equal(reader.read(0)[:0], b'') - assert_equal(reader.read(2)[:2], b'ef') + assert reader.read(2)[:2] == b'ab' + assert reader.read(2)[:2] == b'cd' + assert reader.read(0)[:0] == b'' + assert reader.read(2)[:2] == b'ef' -def test_read2(assert_equal, assert_raises): +def test_read2(): cdef BytesIOReader reader = BytesIOReader(b'abcdef') reader.read(5) reader.read(1) -def test_read3(assert_equal, assert_raises): +def test_read3(): cdef BytesIOReader reader = BytesIOReader(b'abcdef') reader.read(6) -def test_read_eof(assert_equal, assert_raises): +def test_read_eof(): cdef BytesIOReader reader = BytesIOReader(b'abcdef') reader.read(5) - # cannot convert reader.read to an object, do it manually - # assert_raises(EOFError, reader.read, 2) - try: + with pytest.raises(EOFError): reader.read(2) - except EOFError: - pass - else: - raise Exception("Expected an EOFError") reader.read(1) # see that we can still read this diff --git a/tests/unit/cython/test_bytesio.py b/tests/unit/cython/test_bytesio.py index cd4ea86f52..0f27663391 100644 --- a/tests/unit/cython/test_bytesio.py +++ b/tests/unit/cython/test_bytesio.py @@ -23,10 +23,10 @@ class BytesIOTest(unittest.TestCase): @cythontest def test_reading(self): - bytesio_testhelper.test_read1(self.assertEqual, self.assertRaises) - bytesio_testhelper.test_read2(self.assertEqual, self.assertRaises) - bytesio_testhelper.test_read3(self.assertEqual, self.assertRaises) + bytesio_testhelper.test_read1() + bytesio_testhelper.test_read2() + bytesio_testhelper.test_read3() @cythontest def test_reading_error(self): - bytesio_testhelper.test_read_eof(self.assertEqual, self.assertRaises) + bytesio_testhelper.test_read_eof() diff --git a/tests/unit/cython/test_types.py 
b/tests/unit/cython/test_types.py index 545b82fc11..996be266c0 100644 --- a/tests/unit/cython/test_types.py +++ b/tests/unit/cython/test_types.py @@ -22,8 +22,8 @@ class TypesTest(unittest.TestCase): @cythontest def test_datetype(self): - types_testhelper.test_datetype(self.assertEqual) + types_testhelper.test_datetype() @cythontest def test_date_side_by_side(self): - types_testhelper.test_date_side_by_side(self.assertEqual) + types_testhelper.test_date_side_by_side() diff --git a/tests/unit/cython/test_utils.py b/tests/unit/cython/test_utils.py index 0e79c235d8..02e4f0590f 100644 --- a/tests/unit/cython/test_utils.py +++ b/tests/unit/cython/test_utils.py @@ -23,4 +23,4 @@ class UtilsTest(unittest.TestCase): @cythontest def test_datetime_from_timestamp(self): - utils_testhelper.test_datetime_from_timestamp(self.assertEqual) + utils_testhelper.test_datetime_from_timestamp() diff --git a/tests/unit/cython/types_testhelper.pyx b/tests/unit/cython/types_testhelper.pyx index 66d2516319..81f9dca114 100644 --- a/tests/unit/cython/types_testhelper.pyx +++ b/tests/unit/cython/types_testhelper.pyx @@ -28,7 +28,7 @@ from cassandra.buffer cimport Buffer from cassandra.deserializers cimport from_binary, Deserializer -def test_datetype(assert_equal): +def test_datetype(): cdef Deserializer des = find_deserializer(DateType) @@ -52,27 +52,27 @@ def test_datetype(assert_equal): # deserialize # epoc expected = 0 - assert_equal(deserialize(expected), datetime.datetime.fromtimestamp(expected, tz=datetime.timezone.utc).replace(tzinfo=None)) + assert deserialize(expected) == datetime.datetime.fromtimestamp(expected, tz=datetime.timezone.utc).replace(tzinfo=None) # beyond 32b expected = 2 ** 33 - assert_equal(deserialize(expected), datetime.datetime(2242, 3, 16, 12, 56, 32)) + assert deserialize(expected) == datetime.datetime(2242, 3, 16, 12, 56, 32) # less than epoc (PYTHON-119) expected = -770172256 - assert_equal(deserialize(expected), datetime.datetime(1945, 8, 5, 23, 15, 44)) + 
assert deserialize(expected) == datetime.datetime(1945, 8, 5, 23, 15, 44) # work around rounding difference among Python versions (PYTHON-230) # This wont pass with the cython extension until we fix the microseconds alignment with CPython #expected = 1424817268.274 - #assert_equal(deserialize(expected), datetime.datetime(2015, 2, 24, 22, 34, 28, 274000)) + #assert deserialize(expected) == datetime.datetime(2015, 2, 24, 22, 34, 28, 274000) # Large date overflow (PYTHON-452) expected = 2177403010.123 - assert_equal(deserialize(expected), datetime.datetime(2038, 12, 31, 10, 10, 10, 123000)) + assert deserialize(expected) == datetime.datetime(2038, 12, 31, 10, 10, 10, 123000) -def test_date_side_by_side(assert_equal): +def test_date_side_by_side(): # Test pure python and cython date deserialization side-by-side # This is meant to detect inconsistent rounding or conversion (PYTHON-480 for example) # The test covers the full range of time deserializable in Python. It bounds through @@ -91,7 +91,7 @@ def test_date_side_by_side(assert_equal): buf.size = bior.size cython_deserialized = from_binary(cython_deserializer, &buf, 0) python_deserialized = DateType.deserialize(blob, 0) - assert_equal(cython_deserialized, python_deserialized) + assert cython_deserialized == python_deserialized # min -> 0 x = int(calendar.timegm(datetime.datetime(1, 1, 1).utctimetuple()) * 1000) diff --git a/tests/unit/cython/utils_testhelper.pyx b/tests/unit/cython/utils_testhelper.pyx index fe67691aa8..e997caa89e 100644 --- a/tests/unit/cython/utils_testhelper.pyx +++ b/tests/unit/cython/utils_testhelper.pyx @@ -17,7 +17,7 @@ import datetime from cassandra.cython_utils cimport datetime_from_timestamp -def test_datetime_from_timestamp(assert_equal): - assert_equal(datetime_from_timestamp(1454781157.123456), datetime.datetime(2016, 2, 6, 17, 52, 37, 123456)) +def test_datetime_from_timestamp(): + assert datetime_from_timestamp(1454781157.123456) == datetime.datetime(2016, 2, 6, 17, 52, 37, 123456) # 
PYTHON-452 - assert_equal(datetime_from_timestamp(2177403010.123456), datetime.datetime(2038, 12, 31, 10, 10, 10, 123456)) + assert datetime_from_timestamp(2177403010.123456) == datetime.datetime(2038, 12, 31, 10, 10, 10, 123456) diff --git a/tests/unit/io/test_asyncioreactor.py b/tests/unit/io/test_asyncioreactor.py index a6179e122d..c189aa3d74 100644 --- a/tests/unit/io/test_asyncioreactor.py +++ b/tests/unit/io/test_asyncioreactor.py @@ -74,4 +74,4 @@ def test_timer_cancellation(self): # Release context allow for timer thread to run. time.sleep(.2) # Assert that the cancellation was honored - self.assertFalse(callback.was_invoked()) + assert not callback.was_invoked() diff --git a/tests/unit/io/test_libevreactor.py b/tests/unit/io/test_libevreactor.py index 17e03d0fd5..cf7e7caf77 100644 --- a/tests/unit/io/test_libevreactor.py +++ b/tests/unit/io/test_libevreactor.py @@ -83,8 +83,8 @@ def test_watchers_are_finished(self): # be called libev__cleanup(_global_loop) for conn in live_connections: - self.assertTrue(conn._write_watcher.stop.mock_calls) - self.assertTrue(conn._read_watcher.stop.mock_calls) + assert conn._write_watcher.stop.mock_calls + assert conn._read_watcher.stop.mock_calls _global_loop._shutdown = False diff --git a/tests/unit/io/test_twistedreactor.py b/tests/unit/io/test_twistedreactor.py index fd17d8454f..54abe884ae 100644 --- a/tests/unit/io/test_twistedreactor.py +++ b/tests/unit/io/test_twistedreactor.py @@ -76,7 +76,7 @@ def test_makeConnection(self): object that a successful connection was made. 
""" self.obj_ut.makeConnection(self.tr) - self.assertTrue(self.mock_connection.client_connection_made.called) + assert self.mock_connection.client_connection_made.called def test_receiving_data(self): """ @@ -85,7 +85,7 @@ def test_receiving_data(self): """ self.obj_ut.makeConnection(self.tr) self.obj_ut.dataReceived('foobar') - self.assertTrue(self.mock_connection.handle_read.called) + assert self.mock_connection.handle_read.called self.mock_connection._iobuf.write.assert_called_with("foobar") @@ -136,36 +136,35 @@ def test_close(self, mock_connectTCP): self.obj_ut.is_closed = False self.obj_ut.close() - self.assertTrue(self.obj_ut.connected_event.is_set()) - self.assertTrue(self.obj_ut.error_all_requests.called) + assert self.obj_ut.connected_event.is_set() + assert self.obj_ut.error_all_requests.called def test_handle_read__incomplete(self): """ Verify that handle_read() processes incomplete messages properly. """ self.obj_ut.process_msg = Mock() - self.assertEqual(self.obj_ut._iobuf.getvalue(), b'') # buf starts empty + assert self.obj_ut._iobuf.getvalue() == b'' # buf starts empty # incomplete header self.obj_ut._iobuf.write(b'\x84\x00\x00\x00\x00') self.obj_ut.handle_read() - self.assertEqual(self.obj_ut._io_buffer.cql_frame_buffer.getvalue(), b'\x84\x00\x00\x00\x00') + assert self.obj_ut._io_buffer.cql_frame_buffer.getvalue() == b'\x84\x00\x00\x00\x00' # full header, but incomplete body self.obj_ut._iobuf.write(b'\x00\x00\x00\x15') self.obj_ut.handle_read() - self.assertEqual(self.obj_ut._io_buffer.cql_frame_buffer.getvalue(), - b'\x84\x00\x00\x00\x00\x00\x00\x00\x15') - self.assertEqual(self.obj_ut._current_frame.end_pos, 30) + assert self.obj_ut._io_buffer.cql_frame_buffer.getvalue() == b'\x84\x00\x00\x00\x00\x00\x00\x00\x15' + assert self.obj_ut._current_frame.end_pos == 30 # verify we never attempted to process the incomplete message - self.assertFalse(self.obj_ut.process_msg.called) + assert not self.obj_ut.process_msg.called def 
test_handle_read__fullmessage(self): """ Verify that handle_read() processes complete messages properly. """ self.obj_ut.process_msg = Mock() - self.assertEqual(self.obj_ut._iobuf.getvalue(), b'') # buf starts empty + assert self.obj_ut._iobuf.getvalue() == b'' # buf starts empty # write a complete message, plus 'NEXT' (to simulate next message) # assumes protocol v3+ as default Connection.protocol_version @@ -174,7 +173,7 @@ def test_handle_read__fullmessage(self): self.obj_ut._iobuf.write( b'\x84\x01\x00\x02\x03\x00\x00\x00\x15' + body + extra) self.obj_ut.handle_read() - self.assertEqual(self.obj_ut._io_buffer.cql_frame_buffer.getvalue(), extra) + assert self.obj_ut._io_buffer.cql_frame_buffer.getvalue() == extra self.obj_ut.process_msg.assert_called_with( _Frame(version=4, flags=1, stream=2, opcode=3, body_offset=9, end_pos=9 + len(body)), body) diff --git a/tests/unit/io/utils.py b/tests/unit/io/utils.py index fa9017ffa2..174137225a 100644 --- a/tests/unit/io/utils.py +++ b/tests/unit/io/utils.py @@ -36,6 +36,7 @@ from socket import error as socket_error import ssl import time +import pytest log = logging.getLogger(__name__) @@ -128,7 +129,7 @@ def submit_and_wait_for_completion(unit_test, create_timer, start, end, incremen # ensure they are all called back in a timely fashion for callback in completed_callbacks: - unit_test.assertAlmostEqual(callback.expected_wait, callback.get_wait_time(), delta=.15) + assert callback.expected_wait == pytest.approx(callback.get_wait_time(), abs=.15) def noop_if_monkey_patched(f): @@ -181,9 +182,9 @@ def test_timer_cancellation(self): time.sleep(timeout * 2) timer_manager = self._timers # Assert that the cancellation was honored - self.assertFalse(timer_manager._queue) - self.assertFalse(timer_manager._new_timers) - self.assertFalse(callback.was_invoked()) + assert not timer_manager._queue + assert not timer_manager._new_timers + assert not callback.was_invoked() class ReactorTestMixin(object): @@ -249,7 +250,7 @@ def 
test_successful_connection(self): self.get_socket(c).recv.return_value = self.make_msg(header) c.handle_read(*self.null_handle_function_args) - self.assertTrue(c.connected_event.is_set()) + assert c.connected_event.is_set() return c def test_eagain_on_buffer_size(self): @@ -310,7 +311,7 @@ def chunk(size): # Ensure the message size is the good one and that the # message has been processed if it is non-empty - self.assertEqual(c._io_buffer.io_buffer.tell(), expected_size) + assert c._io_buffer.io_buffer.tell() == expected_size if expected_size == 0: c.process_io_buffer.assert_not_called() else: @@ -329,9 +330,9 @@ def test_protocol_error(self): c.handle_read(*self.null_handle_function_args) # make sure it errored correctly - self.assertTrue(c.is_defunct) - self.assertTrue(c.connected_event.is_set()) - self.assertIsInstance(c.last_error, ProtocolError) + assert c.is_defunct + assert c.connected_event.is_set() + assert isinstance(c.last_error, ProtocolError) def test_error_message_on_startup(self): c = self.make_connection() @@ -354,9 +355,9 @@ def test_error_message_on_startup(self): c.handle_read(*self.null_handle_function_args) # make sure it errored correctly - self.assertTrue(c.is_defunct) - self.assertIsInstance(c.last_error, ConnectionException) - self.assertTrue(c.connected_event.is_set()) + assert c.is_defunct + assert isinstance(c.last_error, ConnectionException) + assert c.connected_event.is_set() def test_socket_error_on_write(self): c = self.make_connection() @@ -366,9 +367,9 @@ def test_socket_error_on_write(self): c.handle_write(*self.null_handle_function_args) # make sure it errored correctly - self.assertTrue(c.is_defunct) - self.assertIsInstance(c.last_error, socket_error) - self.assertTrue(c.connected_event.is_set()) + assert c.is_defunct + assert isinstance(c.last_error, socket_error) + assert c.connected_event.is_set() def test_blocking_on_write(self): c = self.make_connection() @@ -378,13 +379,13 @@ def test_blocking_on_write(self): "socket 
busy") c.handle_write(*self.null_handle_function_args) - self.assertFalse(c.is_defunct) + assert not c.is_defunct # try again with normal behavior self.get_socket(c).send.side_effect = lambda x: len(x) c.handle_write(*self.null_handle_function_args) - self.assertFalse(c.is_defunct) - self.assertTrue(self.get_socket(c).send.call_args is not None) + assert not c.is_defunct + assert self.get_socket(c).send.call_args is not None def test_partial_send(self): c = self.make_connection() @@ -399,10 +400,9 @@ def test_partial_send(self): expected_writes = int(math.ceil(float(msg_size) / write_size)) size_mod = msg_size % write_size last_write_size = size_mod if size_mod else write_size - self.assertFalse(c.is_defunct) - self.assertEqual(expected_writes, self.get_socket(c).send.call_count) - self.assertEqual(last_write_size, - len(self.get_socket(c).send.call_args[0][0])) + assert not c.is_defunct + assert expected_writes == self.get_socket(c).send.call_count + assert last_write_size == len(self.get_socket(c).send.call_args[0][0]) def test_socket_error_on_read(self): c = self.make_connection() @@ -416,9 +416,9 @@ def test_socket_error_on_read(self): c.handle_read(*self.null_handle_function_args) # make sure it errored correctly - self.assertTrue(c.is_defunct) - self.assertIsInstance(c.last_error, socket_error) - self.assertTrue(c.connected_event.is_set()) + assert c.is_defunct + assert isinstance(c.last_error, socket_error) + assert c.connected_event.is_set() def test_partial_header_read(self): c = self.make_connection() @@ -429,11 +429,11 @@ def test_partial_header_read(self): self.get_socket(c).recv.return_value = message[0:1] c.handle_read(*self.null_handle_function_args) - self.assertEqual(c._io_buffer.cql_frame_buffer.getvalue(), message[0:1]) + assert c._io_buffer.cql_frame_buffer.getvalue() == message[0:1] self.get_socket(c).recv.return_value = message[1:] c.handle_read(*self.null_handle_function_args) - self.assertEqual(bytes(), c._io_buffer.io_buffer.getvalue()) + 
assert bytes() == c._io_buffer.io_buffer.getvalue() # let it write out a StartupMessage c.handle_write(*self.null_handle_function_args) @@ -442,8 +442,8 @@ def test_partial_header_read(self): self.get_socket(c).recv.return_value = self.make_msg(header) c.handle_read(*self.null_handle_function_args) - self.assertTrue(c.connected_event.is_set()) - self.assertFalse(c.is_defunct) + assert c.connected_event.is_set() + assert not c.is_defunct def test_partial_message_read(self): c = self.make_connection() @@ -455,12 +455,12 @@ def test_partial_message_read(self): # read in the first nine bytes self.get_socket(c).recv.return_value = message[:9] c.handle_read(*self.null_handle_function_args) - self.assertEqual(c._io_buffer.cql_frame_buffer.getvalue(), message[:9]) + assert c._io_buffer.cql_frame_buffer.getvalue() == message[:9] # ... then read in the rest self.get_socket(c).recv.return_value = message[9:] c.handle_read(*self.null_handle_function_args) - self.assertEqual(bytes(), c._io_buffer.io_buffer.getvalue()) + assert bytes() == c._io_buffer.io_buffer.getvalue() # let it write out a StartupMessage c.handle_write(*self.null_handle_function_args) @@ -469,8 +469,8 @@ def test_partial_message_read(self): self.get_socket(c).recv.return_value = self.make_msg(header) c.handle_read(*self.null_handle_function_args) - self.assertTrue(c.connected_event.is_set()) - self.assertFalse(c.is_defunct) + assert c.connected_event.is_set() + assert not c.is_defunct def test_mixed_message_and_buffer_sizes(self): """ diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index 0a2427c7ff..776cbd6973 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -22,7 +22,4 @@ class TestPlainTextAuthenticator(unittest.TestCase): def test_evaluate_challenge_with_unicode_data(self): authenticator = PlainTextAuthenticator("johnӁ", "doeӁ") - self.assertEqual( - authenticator.evaluate_challenge(b'PLAIN-START'), - "\x00johnӁ\x00doeӁ".encode('utf-8') - ) + assert 
authenticator.evaluate_challenge(b'PLAIN-START') == "\x00johnӁ\x00doeӁ".encode('utf-8') diff --git a/tests/unit/test_cluster.py b/tests/unit/test_cluster.py index e656dad005..180f4b6a8a 100644 --- a/tests/unit/test_cluster.py +++ b/tests/unit/test_cluster.py @@ -27,6 +27,7 @@ from cassandra.query import SimpleStatement, named_tuple_factory, tuple_factory from tests.unit.utils import mock_session_pools from tests import connection_class +import pytest log = logging.getLogger(__name__) @@ -38,51 +39,51 @@ def test_exception_types(self): PYTHON-443 Sanity check to ensure we don't unintentionally change class hierarchy of exception types """ - self.assertTrue(issubclass(Unavailable, DriverException)) - self.assertTrue(issubclass(Unavailable, RequestExecutionException)) + assert issubclass(Unavailable, DriverException) + assert issubclass(Unavailable, RequestExecutionException) - self.assertTrue(issubclass(ReadTimeout, DriverException)) - self.assertTrue(issubclass(ReadTimeout, RequestExecutionException)) - self.assertTrue(issubclass(ReadTimeout, Timeout)) + assert issubclass(ReadTimeout, DriverException) + assert issubclass(ReadTimeout, RequestExecutionException) + assert issubclass(ReadTimeout, Timeout) - self.assertTrue(issubclass(WriteTimeout, DriverException)) - self.assertTrue(issubclass(WriteTimeout, RequestExecutionException)) - self.assertTrue(issubclass(WriteTimeout, Timeout)) + assert issubclass(WriteTimeout, DriverException) + assert issubclass(WriteTimeout, RequestExecutionException) + assert issubclass(WriteTimeout, Timeout) - self.assertTrue(issubclass(CoordinationFailure, DriverException)) - self.assertTrue(issubclass(CoordinationFailure, RequestExecutionException)) + assert issubclass(CoordinationFailure, DriverException) + assert issubclass(CoordinationFailure, RequestExecutionException) - self.assertTrue(issubclass(ReadFailure, DriverException)) - self.assertTrue(issubclass(ReadFailure, RequestExecutionException)) - 
self.assertTrue(issubclass(ReadFailure, CoordinationFailure)) + assert issubclass(ReadFailure, DriverException) + assert issubclass(ReadFailure, RequestExecutionException) + assert issubclass(ReadFailure, CoordinationFailure) - self.assertTrue(issubclass(WriteFailure, DriverException)) - self.assertTrue(issubclass(WriteFailure, RequestExecutionException)) - self.assertTrue(issubclass(WriteFailure, CoordinationFailure)) + assert issubclass(WriteFailure, DriverException) + assert issubclass(WriteFailure, RequestExecutionException) + assert issubclass(WriteFailure, CoordinationFailure) - self.assertTrue(issubclass(FunctionFailure, DriverException)) - self.assertTrue(issubclass(FunctionFailure, RequestExecutionException)) + assert issubclass(FunctionFailure, DriverException) + assert issubclass(FunctionFailure, RequestExecutionException) - self.assertTrue(issubclass(RequestValidationException, DriverException)) + assert issubclass(RequestValidationException, DriverException) - self.assertTrue(issubclass(ConfigurationException, DriverException)) - self.assertTrue(issubclass(ConfigurationException, RequestValidationException)) + assert issubclass(ConfigurationException, DriverException) + assert issubclass(ConfigurationException, RequestValidationException) - self.assertTrue(issubclass(AlreadyExists, DriverException)) - self.assertTrue(issubclass(AlreadyExists, RequestValidationException)) - self.assertTrue(issubclass(AlreadyExists, ConfigurationException)) + assert issubclass(AlreadyExists, DriverException) + assert issubclass(AlreadyExists, RequestValidationException) + assert issubclass(AlreadyExists, ConfigurationException) - self.assertTrue(issubclass(InvalidRequest, DriverException)) - self.assertTrue(issubclass(InvalidRequest, RequestValidationException)) + assert issubclass(InvalidRequest, DriverException) + assert issubclass(InvalidRequest, RequestValidationException) - self.assertTrue(issubclass(Unauthorized, DriverException)) - 
self.assertTrue(issubclass(Unauthorized, RequestValidationException)) + assert issubclass(Unauthorized, DriverException) + assert issubclass(Unauthorized, RequestValidationException) - self.assertTrue(issubclass(AuthenticationFailed, DriverException)) + assert issubclass(AuthenticationFailed, DriverException) - self.assertTrue(issubclass(OperationTimedOut, DriverException)) + assert issubclass(OperationTimedOut, DriverException) - self.assertTrue(issubclass(UnsupportedOperation, DriverException)) + assert issubclass(UnsupportedOperation, DriverException) class ClusterTest(unittest.TestCase): @@ -92,17 +93,17 @@ def test_tuple_for_contact_points(self): localhost_addr = set([addr[0] for addr in [t for (_,_,_,_,t) in socket.getaddrinfo("localhost",80)]]) for cp in cluster.endpoints_resolved: if cp.address in localhost_addr: - self.assertEqual(cp.port, 9045) + assert cp.port == 9045 elif cp.address == '127.0.0.2': - self.assertEqual(cp.port, 9046) + assert cp.port == 9046 else: - self.assertEqual(cp.address, '127.0.0.3') - self.assertEqual(cp.port, 9999) + assert cp.address == '127.0.0.3' + assert cp.port == 9999 def test_invalid_contact_point_types(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Cluster(contact_points=[None], protocol_version=4, connect_timeout=1) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): Cluster(contact_points="not a sequence", protocol_version=4, connect_timeout=1) def test_port_str(self): @@ -110,15 +111,15 @@ def test_port_str(self): cluster = Cluster(contact_points=['127.0.0.1'], port='1111') for cp in cluster.endpoints_resolved: if cp.address in ('::1', '127.0.0.1'): - self.assertEqual(cp.port, 1111) + assert cp.port == 1111 - with self.assertRaises(ValueError): + with pytest.raises(ValueError): cluster = Cluster(contact_points=['127.0.0.1'], port='string') def test_port_range(self): for invalid_port in [0, 65536, -1]: - with self.assertRaises(ValueError): + with 
pytest.raises(ValueError): cluster = Cluster(contact_points=['127.0.0.1'], port=invalid_port) @@ -159,20 +160,20 @@ def test_default_serial_consistency_level_ep(self, *_): # default is None default_profile = c.profile_manager.default - self.assertIsNone(default_profile.serial_consistency_level) + assert default_profile.serial_consistency_level is None for cl in (None, ConsistencyLevel.LOCAL_SERIAL, ConsistencyLevel.SERIAL): s.get_execution_profile(EXEC_PROFILE_DEFAULT).serial_consistency_level = cl # default is passed through f = s.execute_async(query='') - self.assertEqual(f.message.serial_consistency_level, cl) + assert f.message.serial_consistency_level == cl # any non-None statement setting takes precedence for cl_override in (ConsistencyLevel.LOCAL_SERIAL, ConsistencyLevel.SERIAL): f = s.execute_async(SimpleStatement(query_string='', serial_consistency_level=cl_override)) - self.assertEqual(default_profile.serial_consistency_level, cl) - self.assertEqual(f.message.serial_consistency_level, cl_override) + assert default_profile.serial_consistency_level == cl + assert f.message.serial_consistency_level == cl_override @mock_session_pools def test_default_serial_consistency_level_legacy(self, *_): @@ -186,12 +187,12 @@ def test_default_serial_consistency_level_legacy(self, *_): s = Session(c, [Host("127.0.0.1", SimpleConvictionPolicy)]) c.connection_class.initialize_reactor() # default is None - self.assertIsNone(s.default_serial_consistency_level) + assert s.default_serial_consistency_level is None # Should fail - with self.assertRaises(ValueError): + with pytest.raises(ValueError): s.default_serial_consistency_level = ConsistencyLevel.ANY - with self.assertRaises(ValueError): + with pytest.raises(ValueError): s.default_serial_consistency_level = 1001 for cl in (None, ConsistencyLevel.LOCAL_SERIAL, ConsistencyLevel.SERIAL): @@ -200,29 +201,29 @@ def test_default_serial_consistency_level_legacy(self, *_): # any non-None statement setting takes precedence for 
cl_override in (ConsistencyLevel.LOCAL_SERIAL, ConsistencyLevel.SERIAL): f = s.execute_async(SimpleStatement(query_string='', serial_consistency_level=cl_override)) - self.assertEqual(s.default_serial_consistency_level, cl) - self.assertEqual(f.message.serial_consistency_level, cl_override) + assert s.default_serial_consistency_level == cl + assert f.message.serial_consistency_level == cl_override class ProtocolVersionTests(unittest.TestCase): def test_protocol_downgrade_test(self): lower = ProtocolVersion.get_lower_supported(ProtocolVersion.DSE_V2) - self.assertEqual(ProtocolVersion.DSE_V1, lower) + assert ProtocolVersion.DSE_V1 == lower lower = ProtocolVersion.get_lower_supported(ProtocolVersion.DSE_V1) - self.assertEqual(ProtocolVersion.V5,lower) + assert ProtocolVersion.V5 == lower lower = ProtocolVersion.get_lower_supported(ProtocolVersion.V5) - self.assertEqual(ProtocolVersion.V4,lower) + assert ProtocolVersion.V4 == lower lower = ProtocolVersion.get_lower_supported(ProtocolVersion.V4) - self.assertEqual(ProtocolVersion.V3,lower) + assert ProtocolVersion.V3 == lower lower = ProtocolVersion.get_lower_supported(ProtocolVersion.V3) - self.assertEqual(0, lower) + assert 0 == lower - self.assertTrue(ProtocolVersion.uses_error_code_map(ProtocolVersion.DSE_V1)) - self.assertTrue(ProtocolVersion.uses_int_query_flags(ProtocolVersion.DSE_V1)) + assert ProtocolVersion.uses_error_code_map(ProtocolVersion.DSE_V1) + assert ProtocolVersion.uses_int_query_flags(ProtocolVersion.DSE_V1) - self.assertFalse(ProtocolVersion.uses_error_code_map(ProtocolVersion.V4)) - self.assertFalse(ProtocolVersion.uses_int_query_flags(ProtocolVersion.V4)) + assert not ProtocolVersion.uses_error_code_map(ProtocolVersion.V4) + assert not ProtocolVersion.uses_int_query_flags(ProtocolVersion.V4) class ExecutionProfileTest(unittest.TestCase): @@ -232,35 +233,35 @@ def setUp(self): connection_class.initialize_reactor() def _verify_response_future_profile(self, rf, prof): - 
self.assertEqual(rf._load_balancer, prof.load_balancing_policy) - self.assertEqual(rf._retry_policy, prof.retry_policy) - self.assertEqual(rf.message.consistency_level, prof.consistency_level) - self.assertEqual(rf.message.serial_consistency_level, prof.serial_consistency_level) - self.assertEqual(rf.timeout, prof.request_timeout) - self.assertEqual(rf.row_factory, prof.row_factory) + assert rf._load_balancer == prof.load_balancing_policy + assert rf._retry_policy == prof.retry_policy + assert rf.message.consistency_level == prof.consistency_level + assert rf.message.serial_consistency_level == prof.serial_consistency_level + assert rf.timeout == prof.request_timeout + assert rf.row_factory == prof.row_factory @mock_session_pools def test_default_exec_parameters(self): cluster = Cluster() - self.assertEqual(cluster._config_mode, _ConfigMode.UNCOMMITTED) - self.assertEqual(cluster.load_balancing_policy.__class__, default_lbp_factory().__class__) - self.assertEqual(cluster.profile_manager.default.load_balancing_policy.__class__, default_lbp_factory().__class__) - self.assertEqual(cluster.default_retry_policy.__class__, RetryPolicy) - self.assertEqual(cluster.profile_manager.default.retry_policy.__class__, RetryPolicy) + assert cluster._config_mode == _ConfigMode.UNCOMMITTED + assert cluster.load_balancing_policy.__class__ == default_lbp_factory().__class__ + assert cluster.profile_manager.default.load_balancing_policy.__class__ == default_lbp_factory().__class__ + assert cluster.default_retry_policy.__class__ == RetryPolicy + assert cluster.profile_manager.default.retry_policy.__class__ == RetryPolicy session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) - self.assertEqual(session.default_timeout, 10.0) - self.assertEqual(cluster.profile_manager.default.request_timeout, 10.0) - self.assertEqual(session.default_consistency_level, ConsistencyLevel.LOCAL_ONE) - self.assertEqual(cluster.profile_manager.default.consistency_level, 
ConsistencyLevel.LOCAL_ONE) - self.assertEqual(session.default_serial_consistency_level, None) - self.assertEqual(cluster.profile_manager.default.serial_consistency_level, None) - self.assertEqual(session.row_factory, named_tuple_factory) - self.assertEqual(cluster.profile_manager.default.row_factory, named_tuple_factory) + assert session.default_timeout == 10.0 + assert cluster.profile_manager.default.request_timeout == 10.0 + assert session.default_consistency_level == ConsistencyLevel.LOCAL_ONE + assert cluster.profile_manager.default.consistency_level == ConsistencyLevel.LOCAL_ONE + assert session.default_serial_consistency_level == None + assert cluster.profile_manager.default.serial_consistency_level == None + assert session.row_factory == named_tuple_factory + assert cluster.profile_manager.default.row_factory == named_tuple_factory @mock_session_pools def test_default_legacy(self): cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), default_retry_policy=DowngradingConsistencyRetryPolicy()) - self.assertEqual(cluster._config_mode, _ConfigMode.LEGACY) + assert cluster._config_mode == _ConfigMode.LEGACY session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) session.default_timeout = 3.7 session.default_consistency_level = ConsistencyLevel.ALL @@ -277,7 +278,7 @@ def test_default_profile(self): cluster = Cluster(execution_profiles={'non-default': non_default_profile}) session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) - self.assertEqual(cluster._config_mode, _ConfigMode.PROFILES) + assert cluster._config_mode == _ConfigMode.PROFILES default_profile = cluster.profile_manager.profiles[EXEC_PROFILE_DEFAULT] rf = session.execute_async("query") @@ -287,10 +288,10 @@ def test_default_profile(self): self._verify_response_future_profile(rf, non_default_profile) for name, ep in cluster.profile_manager.profiles.items(): - self.assertEqual(ep, session.get_execution_profile(name)) + assert ep == 
session.get_execution_profile(name) # invalid ep - with self.assertRaises(ValueError): + with pytest.raises(ValueError): session.get_execution_profile('non-existent') def test_serial_consistency_level_validation(self): @@ -299,25 +300,25 @@ def test_serial_consistency_level_validation(self): ep = ExecutionProfile(RoundRobinPolicy(), serial_consistency_level=ConsistencyLevel.LOCAL_SERIAL) # should not pass - with self.assertRaises(ValueError): + with pytest.raises(ValueError): ep = ExecutionProfile(RoundRobinPolicy(), serial_consistency_level=ConsistencyLevel.ANY) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): ep = ExecutionProfile(RoundRobinPolicy(), serial_consistency_level=42) @mock_session_pools def test_statement_params_override_legacy(self): cluster = Cluster(load_balancing_policy=RoundRobinPolicy(), default_retry_policy=DowngradingConsistencyRetryPolicy()) - self.assertEqual(cluster._config_mode, _ConfigMode.LEGACY) + assert cluster._config_mode == _ConfigMode.LEGACY session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) ss = SimpleStatement("query", retry_policy=DowngradingConsistencyRetryPolicy(), consistency_level=ConsistencyLevel.ALL, serial_consistency_level=ConsistencyLevel.SERIAL) my_timeout = 1.1234 - self.assertNotEqual(ss.retry_policy.__class__, cluster.default_retry_policy) - self.assertNotEqual(ss.consistency_level, session.default_consistency_level) - self.assertNotEqual(ss._serial_consistency_level, session.default_serial_consistency_level) - self.assertNotEqual(my_timeout, session.default_timeout) + assert ss.retry_policy.__class__ != cluster.default_retry_policy + assert ss.consistency_level != session.default_consistency_level + assert ss._serial_consistency_level != session.default_serial_consistency_level + assert my_timeout != session.default_timeout rf = session.execute_async(ss, timeout=my_timeout) expected_profile = ExecutionProfile(load_balancing_policy=cluster.load_balancing_policy, 
retry_policy=ss.retry_policy, @@ -331,7 +332,7 @@ def test_statement_params_override_profile(self): cluster = Cluster(execution_profiles={'non-default': non_default_profile}) session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) - self.assertEqual(cluster._config_mode, _ConfigMode.PROFILES) + assert cluster._config_mode == _ConfigMode.PROFILES rf = session.execute_async("query", execution_profile='non-default') @@ -339,10 +340,10 @@ def test_statement_params_override_profile(self): consistency_level=ConsistencyLevel.ALL, serial_consistency_level=ConsistencyLevel.SERIAL) my_timeout = 1.1234 - self.assertNotEqual(ss.retry_policy.__class__, rf._load_balancer.__class__) - self.assertNotEqual(ss.consistency_level, rf.message.consistency_level) - self.assertNotEqual(ss._serial_consistency_level, rf.message.serial_consistency_level) - self.assertNotEqual(my_timeout, rf.timeout) + assert ss.retry_policy.__class__ != rf._load_balancer.__class__ + assert ss.consistency_level != rf.message.consistency_level + assert ss._serial_consistency_level != rf.message.serial_consistency_level + assert my_timeout != rf.timeout rf = session.execute_async(ss, timeout=my_timeout, execution_profile='non-default') expected_profile = ExecutionProfile(non_default_profile.load_balancing_policy, ss.retry_policy, @@ -352,14 +353,18 @@ def test_statement_params_override_profile(self): @mock_session_pools def test_no_profile_with_legacy(self): # don't construct with both - self.assertRaises(ValueError, Cluster, load_balancing_policy=RoundRobinPolicy(), execution_profiles={'a': ExecutionProfile()}) - self.assertRaises(ValueError, Cluster, default_retry_policy=DowngradingConsistencyRetryPolicy(), execution_profiles={'a': ExecutionProfile()}) - self.assertRaises(ValueError, Cluster, load_balancing_policy=RoundRobinPolicy(), + with pytest.raises(ValueError): + Cluster(load_balancing_policy=RoundRobinPolicy(), execution_profiles={'a': ExecutionProfile()}) + with 
pytest.raises(ValueError): + Cluster(default_retry_policy=DowngradingConsistencyRetryPolicy(), execution_profiles={'a': ExecutionProfile()}) + with pytest.raises(ValueError): + Cluster(load_balancing_policy=RoundRobinPolicy(), default_retry_policy=DowngradingConsistencyRetryPolicy(), execution_profiles={'a': ExecutionProfile()}) # can't add after cluster = Cluster(load_balancing_policy=RoundRobinPolicy()) - self.assertRaises(ValueError, cluster.add_execution_profile, 'name', ExecutionProfile()) + with pytest.raises(ValueError): + cluster.add_execution_profile('name', ExecutionProfile()) # session settings lock out profiles cluster = Cluster() @@ -370,10 +375,12 @@ def test_no_profile_with_legacy(self): ('row_factory', tuple_factory)): cluster._config_mode = _ConfigMode.UNCOMMITTED setattr(session, attr, value) - self.assertRaises(ValueError, cluster.add_execution_profile, 'name' + attr, ExecutionProfile()) + with pytest.raises(ValueError): + cluster.add_execution_profile('name' + attr, ExecutionProfile()) # don't accept profile - self.assertRaises(ValueError, session.execute_async, "query", execution_profile='some name here') + with pytest.raises(ValueError): + session.execute_async("query", execution_profile='some name here') @mock_session_pools def test_no_legacy_with_profile(self): @@ -385,13 +392,15 @@ def test_no_legacy_with_profile(self): # don't allow legacy parameters set for attr, value in (('default_retry_policy', RetryPolicy()), ('load_balancing_policy', default_lbp_factory())): - self.assertRaises(ValueError, setattr, cluster, attr, value) + with pytest.raises(ValueError): + setattr(cluster, attr, value) session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) for attr, value in (('default_timeout', 1), ('default_consistency_level', ConsistencyLevel.ANY), ('default_serial_consistency_level', ConsistencyLevel.SERIAL), ('row_factory', tuple_factory)): - self.assertRaises(ValueError, setattr, session, attr, value) + with 
pytest.raises(ValueError): + setattr(session, attr, value) @mock_session_pools def test_profile_name_value(self): @@ -399,7 +408,7 @@ def test_profile_name_value(self): internalized_profile = ExecutionProfile(RoundRobinPolicy(), *[object() for _ in range(2)]) cluster = Cluster(execution_profiles={'by-name': internalized_profile}) session = Session(cluster, hosts=[Host("127.0.0.1", SimpleConvictionPolicy)]) - self.assertEqual(cluster._config_mode, _ConfigMode.PROFILES) + assert cluster._config_mode == _ConfigMode.PROFILES rf = session.execute_async("query", execution_profile='by-name') self._verify_response_future_profile(rf, internalized_profile) @@ -426,34 +435,38 @@ def test_exec_profile_clone(self): for profile in (EXEC_PROFILE_DEFAULT, 'one'): active = session.get_execution_profile(profile) clone = session.execution_profile_clone_update(profile) - self.assertIsNot(clone, active) + assert clone is not active all_updated = session.execution_profile_clone_update(clone, **profile_attrs) - self.assertIsNot(all_updated, clone) + assert all_updated is not clone for attr, value in profile_attrs.items(): - self.assertEqual(getattr(clone, attr), getattr(active, attr)) + assert getattr(clone, attr) == getattr(active, attr) if attr in reference_attributes: - self.assertIs(getattr(clone, attr), getattr(active, attr)) - self.assertNotEqual(getattr(all_updated, attr), getattr(active, attr)) + assert getattr(clone, attr) is getattr(active, attr) + assert getattr(all_updated, attr) != getattr(active, attr) # cannot clone nonexistent profile - self.assertRaises(ValueError, session.execution_profile_clone_update, 'DOES NOT EXIST', **profile_attrs) + with pytest.raises(ValueError): + session.execution_profile_clone_update('DOES NOT EXIST', **profile_attrs) def test_no_profiles_same_name(self): # can override default in init cluster = Cluster(execution_profiles={EXEC_PROFILE_DEFAULT: ExecutionProfile(), 'one': ExecutionProfile()}) # cannot update default - 
self.assertRaises(ValueError, cluster.add_execution_profile, EXEC_PROFILE_DEFAULT, ExecutionProfile()) + with pytest.raises(ValueError): + cluster.add_execution_profile(EXEC_PROFILE_DEFAULT, ExecutionProfile()) # cannot update named init - self.assertRaises(ValueError, cluster.add_execution_profile, 'one', ExecutionProfile()) + with pytest.raises(ValueError): + cluster.add_execution_profile('one', ExecutionProfile()) # can add new name cluster.add_execution_profile('two', ExecutionProfile()) # cannot add a profile added dynamically - self.assertRaises(ValueError, cluster.add_execution_profile, 'two', ExecutionProfile()) + with pytest.raises(ValueError): + cluster.add_execution_profile('two', ExecutionProfile()) def test_warning_on_no_lbp_with_contact_points_legacy_mode(self): """ @@ -493,8 +506,8 @@ def _check_warning_on_no_lbp_with_contact_points(self, cluster_kwargs): Cluster(**cluster_kwargs) patched_logger.warning.assert_called_once() warning_message = patched_logger.warning.call_args[0][0] - self.assertIn('please specify a load-balancing policy', warning_message) - self.assertIn("contact_points = ['127.0.0.1']", warning_message) + assert 'please specify a load-balancing policy' in warning_message + assert "contact_points = ['127.0.0.1']" in warning_message def test_no_warning_on_contact_points_with_lbp_legacy_mode(self): """ @@ -562,9 +575,9 @@ def test_warning_adding_no_lbp_ep_to_cluster_with_contact_points(self): patched_logger.warning.assert_called_once() warning_message = patched_logger.warning.call_args[0][0] - self.assertIn('no_lbp', warning_message) - self.assertIn('trying to add', warning_message) - self.assertIn('please specify a load-balancing policy', warning_message) + assert 'no_lbp' in warning_message + assert 'trying to add' in warning_message + assert 'please specify a load-balancing policy' in warning_message @mock_session_pools def test_no_warning_adding_lbp_ep_to_cluster_with_contact_points(self): diff --git a/tests/unit/test_concurrent.py 
b/tests/unit/test_concurrent.py index bdfd08126e..a3587a3e16 100644 --- a/tests/unit/test_concurrent.py +++ b/tests/unit/test_concurrent.py @@ -28,6 +28,7 @@ from cassandra.pool import Host from cassandra.policies import SimpleConvictionPolicy from tests.unit.utils import mock_session_pools +import pytest class MockResponseResponseFuture(): @@ -229,14 +230,14 @@ def validate_result_ordering(self, results): """ last_time_added = 0 for success, result in results: - self.assertTrue(success) + assert success current_time_added = list(result)[0] #Windows clock granularity makes this equal most of the times if "Windows" in platform.system(): - self.assertLessEqual(last_time_added, current_time_added) + assert last_time_added <= current_time_added else: - self.assertLess(last_time_added, current_time_added) + assert last_time_added < current_time_added last_time_added = current_time_added @mock_session_pools @@ -248,10 +249,11 @@ def test_recursion_limited(self): """ max_recursion = sys.getrecursionlimit() s = Session(Cluster(), [Host("127.0.0.1", SimpleConvictionPolicy)]) - self.assertRaises(TypeError, execute_concurrent_with_args, s, "doesn't matter", [('param',)] * max_recursion, raise_on_first_error=True) + with pytest.raises(TypeError): + execute_concurrent_with_args(s, "doesn't matter", [('param',)] * max_recursion, raise_on_first_error=True) results = execute_concurrent_with_args(s, "doesn't matter", [('param',)] * max_recursion, raise_on_first_error=False) # previously - self.assertEqual(len(results), max_recursion) + assert len(results) == max_recursion for r in results: - self.assertFalse(r[0]) - self.assertIsInstance(r[1], TypeError) + assert not r[0] + assert isinstance(r[1], TypeError) diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 17b23b5ed0..c21f5c7212 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -27,7 +27,8 @@ from cassandra.protocol import (write_stringmultimap, write_int, 
write_string, SupportedMessage, ProtocolHandler) -from tests.util import wait_until +from tests.util import wait_until, assertRegex +import pytest class ConnectionTest(unittest.TestCase): @@ -67,17 +68,17 @@ def make_msg(self, header, body=""): def test_connection_endpoint(self): endpoint = DefaultEndPoint('1.2.3.4') c = Connection(endpoint) - self.assertEqual(c.endpoint, endpoint) - self.assertEqual(c.endpoint.address, endpoint.address) + assert c.endpoint == endpoint + assert c.endpoint.address == endpoint.address c = Connection(host=endpoint) # kwarg - self.assertEqual(c.endpoint, endpoint) - self.assertEqual(c.endpoint.address, endpoint.address) + assert c.endpoint == endpoint + assert c.endpoint.address == endpoint.address c = Connection('10.0.0.1') endpoint = DefaultEndPoint('10.0.0.1') - self.assertEqual(c.endpoint, endpoint) - self.assertEqual(c.endpoint.address, endpoint.address) + assert c.endpoint == endpoint + assert c.endpoint.address == endpoint.address def test_bad_protocol_version(self, *args): c = self.make_connection() @@ -95,7 +96,7 @@ def test_bad_protocol_version(self, *args): # make sure it errored correctly c.defunct.assert_called_once_with(ANY) args, kwargs = c.defunct.call_args - self.assertIsInstance(args[0], ProtocolError) + assert isinstance(args[0], ProtocolError) def test_negative_body_length(self, *args): c = self.make_connection() @@ -112,7 +113,7 @@ def test_negative_body_length(self, *args): # make sure it errored correctly c.defunct.assert_called_once_with(ANY) args, kwargs = c.defunct.call_args - self.assertIsInstance(args[0], ProtocolError) + assert isinstance(args[0], ProtocolError) def test_unsupported_cql_version(self, *args): c = self.make_connection() @@ -132,7 +133,7 @@ def test_unsupported_cql_version(self, *args): # make sure it errored correctly c.defunct.assert_called_once_with(ANY) args, kwargs = c.defunct.call_args - self.assertIsInstance(args[0], ProtocolError) + assert isinstance(args[0], ProtocolError) def 
test_prefer_lz4_compression(self, *args): c = self.make_connection() @@ -155,7 +156,7 @@ def test_prefer_lz4_compression(self, *args): c.process_msg(_Frame(version=4, flags=0, stream=0, opcode=SupportedMessage.opcode, body_offset=9, end_pos=9 + len(options)), options) - self.assertEqual(c.decompressor, locally_supported_compressions['lz4'][1]) + assert c.decompressor == locally_supported_compressions['lz4'][1] def test_requested_compression_not_available(self, *args): c = self.make_connection() @@ -182,7 +183,7 @@ def test_requested_compression_not_available(self, *args): # make sure it errored correctly c.defunct.assert_called_once_with(ANY) args, kwargs = c.defunct.call_args - self.assertIsInstance(args[0], ProtocolError) + assert isinstance(args[0], ProtocolError) def test_use_requested_compression(self, *args): c = self.make_connection() @@ -206,7 +207,7 @@ def test_use_requested_compression(self, *args): c.process_msg(_Frame(version=4, flags=0, stream=0, opcode=SupportedMessage.opcode, body_offset=9, end_pos=9 + len(options)), options) - self.assertEqual(c.decompressor, locally_supported_compressions['snappy'][1]) + assert c.decompressor == locally_supported_compressions['snappy'][1] def test_disable_compression(self, *args): c = self.make_connection() @@ -234,29 +235,30 @@ def test_disable_compression(self, *args): message = self.make_msg(header, options) c.process_msg(message, len(message) - 8) - self.assertEqual(c.decompressor, None) + assert c.decompressor == None def test_not_implemented(self): """ Ensure the following methods throw NIE's. If not, come back and test them. 
""" c = self.make_connection() - self.assertRaises(NotImplementedError, c.close) + with pytest.raises(NotImplementedError): + c.close() def test_set_keyspace_blocking(self): c = self.make_connection() - self.assertEqual(c.keyspace, None) + assert c.keyspace == None c.set_keyspace_blocking(None) - self.assertEqual(c.keyspace, None) + assert c.keyspace == None c.keyspace = 'ks' c.set_keyspace_blocking('ks') - self.assertEqual(c.keyspace, 'ks') + assert c.keyspace == 'ks' def test_set_connection_class(self): cluster = Cluster(connection_class='test') - self.assertEqual('test', cluster.connection_class) + assert 'test' == cluster.connection_class @patch('cassandra.connection.ConnectionHeartbeat._raise_if_stopped') @@ -278,7 +280,7 @@ def run_heartbeat(self, get_holders_fun, count=2, interval=0.05, timeout=0.05): wait_until(lambda: get_holders_fun.call_count > 0, 0.01, 100) time.sleep(interval * (count-1)) ch.stop() - self.assertTrue(get_holders_fun.call_count) + assert get_holders_fun.call_count def test_empty_connections(self, *args): count = 3 @@ -286,8 +288,8 @@ def test_empty_connections(self, *args): self.run_heartbeat(get_holders, count) - self.assertGreaterEqual(get_holders.call_count, count-1) - self.assertLessEqual(get_holders.call_count, count) + assert get_holders.call_count >= count-1 + assert get_holders.call_count <= count holder = get_holders.return_value[0] holder.get_connections.assert_has_calls([call()] * get_holders.call_count) @@ -315,11 +317,11 @@ def send_msg(msg, req_id, msg_callback): self.run_heartbeat(get_holders) holder.get_connections.assert_has_calls([call()] * get_holders.call_count) - self.assertEqual(idle_connection.in_flight, 0) - self.assertEqual(non_idle_connection.in_flight, 0) + assert idle_connection.in_flight == 0 + assert non_idle_connection.in_flight == 0 idle_connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count) - self.assertEqual(non_idle_connection.send_msg.call_count, 0) + assert 
non_idle_connection.send_msg.call_count == 0 def test_closed_defunct(self, *args): get_holders = self.make_get_holders(1) @@ -332,10 +334,10 @@ def test_closed_defunct(self, *args): self.run_heartbeat(get_holders) holder.get_connections.assert_has_calls([call()] * get_holders.call_count) - self.assertEqual(closed_connection.in_flight, 0) - self.assertEqual(defunct_connection.in_flight, 0) - self.assertEqual(closed_connection.send_msg.call_count, 0) - self.assertEqual(defunct_connection.send_msg.call_count, 0) + assert closed_connection.in_flight == 0 + assert defunct_connection.in_flight == 0 + assert closed_connection.send_msg.call_count == 0 + assert defunct_connection.send_msg.call_count == 0 def test_no_req_ids(self, *args): in_flight = 3 @@ -351,9 +353,9 @@ def test_no_req_ids(self, *args): self.run_heartbeat(get_holders) holder.get_connections.assert_has_calls([call()] * get_holders.call_count) - self.assertEqual(max_connection.in_flight, in_flight) - self.assertEqual(max_connection.send_msg.call_count, 0) - self.assertEqual(max_connection.send_msg.call_count, 0) + assert max_connection.in_flight == in_flight + assert max_connection.send_msg.call_count == 0 + assert max_connection.send_msg.call_count == 0 max_connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count) holder.return_connection.assert_has_calls( [call(max_connection)] * get_holders.call_count) @@ -378,12 +380,12 @@ def send_msg(msg, req_id, msg_callback): self.run_heartbeat(get_holders) - self.assertEqual(connection.in_flight, get_holders.call_count) + assert connection.in_flight == get_holders.call_count connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count) connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count) exc = connection.defunct.call_args_list[0][0][0] - self.assertIsInstance(exc, ConnectionException) - self.assertRegex(exc.args[0], r'^Received unexpected response to OptionsMessage.*') + assert isinstance(exc, 
ConnectionException) + assertRegex(exc.args[0], r'^Received unexpected response to OptionsMessage.*') holder.return_connection.assert_has_calls( [call(connection)] * get_holders.call_count) @@ -408,13 +410,13 @@ def send_msg(msg, req_id, msg_callback): self.run_heartbeat(get_holders) - self.assertEqual(connection.in_flight, get_holders.call_count) + assert connection.in_flight == get_holders.call_count connection.send_msg.assert_has_calls([call(ANY, request_id, ANY)] * get_holders.call_count) connection.defunct.assert_has_calls([call(ANY)] * get_holders.call_count) exc = connection.defunct.call_args_list[0][0][0] - self.assertIsInstance(exc, OperationTimedOut) - self.assertEqual(exc.errors, 'Connection heartbeat timeout after 0.05 seconds') - self.assertEqual(exc.last_host, DefaultEndPoint('localhost')) + assert isinstance(exc, OperationTimedOut) + assert exc.errors == 'Connection heartbeat timeout after 0.05 seconds' + assert exc.last_host == DefaultEndPoint('localhost') holder.return_connection.assert_has_calls( [call(connection)] * get_holders.call_count) @@ -439,46 +441,28 @@ class DefaultEndPointTest(unittest.TestCase): def test_default_endpoint_properties(self): endpoint = DefaultEndPoint('10.0.0.1') - self.assertEqual(endpoint.address, '10.0.0.1') - self.assertEqual(endpoint.port, 9042) - self.assertEqual(str(endpoint), '10.0.0.1:9042') + assert endpoint.address == '10.0.0.1' + assert endpoint.port == 9042 + assert str(endpoint) == '10.0.0.1:9042' endpoint = DefaultEndPoint('10.0.0.1', 8888) - self.assertEqual(endpoint.address, '10.0.0.1') - self.assertEqual(endpoint.port, 8888) - self.assertEqual(str(endpoint), '10.0.0.1:8888') + assert endpoint.address == '10.0.0.1' + assert endpoint.port == 8888 + assert str(endpoint) == '10.0.0.1:8888' def test_endpoint_equality(self): - self.assertEqual( - DefaultEndPoint('10.0.0.1'), - DefaultEndPoint('10.0.0.1') - ) - - self.assertEqual( - DefaultEndPoint('10.0.0.1'), - DefaultEndPoint('10.0.0.1', 9042) - ) - - 
self.assertNotEqual( - DefaultEndPoint('10.0.0.1'), - DefaultEndPoint('10.0.0.2') - ) - - self.assertNotEqual( - DefaultEndPoint('10.0.0.1'), - DefaultEndPoint('10.0.0.1', 0000) - ) + assert DefaultEndPoint('10.0.0.1') == DefaultEndPoint('10.0.0.1') + + assert DefaultEndPoint('10.0.0.1') == DefaultEndPoint('10.0.0.1', 9042) + + assert DefaultEndPoint('10.0.0.1') != DefaultEndPoint('10.0.0.2') + + assert DefaultEndPoint('10.0.0.1') != DefaultEndPoint('10.0.0.1', 0000) def test_endpoint_resolve(self): - self.assertEqual( - DefaultEndPoint('10.0.0.1').resolve(), - ('10.0.0.1', 9042) - ) + assert DefaultEndPoint('10.0.0.1').resolve() == ('10.0.0.1', 9042) - self.assertEqual( - DefaultEndPoint('10.0.0.1', 3232).resolve(), - ('10.0.0.1', 3232) - ) + assert DefaultEndPoint('10.0.0.1', 3232).resolve() == ('10.0.0.1', 3232) class TestShardawarePortGenerator(unittest.TestCase): @@ -489,7 +473,7 @@ def test_generate_ports_basic(self, mock_randrange): ports = list(itertools.islice(gen.generate(shard_id=1, total_shards=3), 5)) # Starting from aligned 10005 + shard_id (1), step by 3 - self.assertEqual(ports, [10006, 10009, 10012, 10015, 10018]) + assert ports == [10006, 10009, 10012, 10015, 10018] @patch('random.randrange') def test_wraps_around_to_start(self, mock_randrange): @@ -498,7 +482,7 @@ def test_wraps_around_to_start(self, mock_randrange): ports = list(itertools.islice(gen.generate(shard_id=2, total_shards=4), 5)) # Expected wrap-around from start_port after end_port is exceeded - self.assertEqual(ports, [10010, 10014, 10018, 10002, 10006]) + assert ports == [10010, 10014, 10018, 10002, 10006] @patch('random.randrange') def test_all_ports_have_correct_modulo(self, mock_randrange): @@ -508,7 +492,7 @@ def test_all_ports_have_correct_modulo(self, mock_randrange): gen = ShardAwarePortGenerator(10000, 10020) for port in gen.generate(shard_id=shard_id, total_shards=total_shards): - self.assertEqual(port % total_shards, shard_id) + assert port % total_shards == shard_id 
@patch('random.randrange') def test_generate_is_repeatable_with_same_mock(self, mock_randrange): @@ -518,4 +502,4 @@ def test_generate_is_repeatable_with_same_mock(self, mock_randrange): first_run = list(itertools.islice(gen.generate(0, 2), 5)) second_run = list(itertools.islice(gen.generate(0, 2), 5)) - self.assertEqual(first_run, second_run) \ No newline at end of file + assert first_run == second_run diff --git a/tests/unit/test_control_connection.py b/tests/unit/test_control_connection.py index 71a6b024cd..d759e12332 100644 --- a/tests/unit/test_control_connection.py +++ b/tests/unit/test_control_connection.py @@ -210,30 +210,30 @@ def test_wait_for_schema_agreement(self): """ Basic test with all schema versions agreeing """ - self.assertTrue(self.control_connection.wait_for_schema_agreement()) + assert self.control_connection.wait_for_schema_agreement() # the control connection should not have slept at all - self.assertEqual(self.time.clock, 0) + assert self.time.clock == 0 def test_wait_for_schema_agreement_uses_preloaded_results_if_given(self): """ wait_for_schema_agreement uses preloaded results if given for shared table queries """ preloaded_results = self._matching_schema_preloaded_results - self.assertTrue(self.control_connection.wait_for_schema_agreement(preloaded_results=preloaded_results)) + assert self.control_connection.wait_for_schema_agreement(preloaded_results=preloaded_results) # the control connection should not have slept at all - self.assertEqual(self.time.clock, 0) + assert self.time.clock == 0 # the connection should not have made any queries if given preloaded results - self.assertEqual(self.connection.wait_for_responses.call_count, 0) + assert self.connection.wait_for_responses.call_count == 0 def test_wait_for_schema_agreement_falls_back_to_querying_if_schemas_dont_match_preloaded_result(self): """ wait_for_schema_agreement requery if schema does not match using preloaded results """ preloaded_results = 
self._nonmatching_schema_preloaded_results - self.assertTrue(self.control_connection.wait_for_schema_agreement(preloaded_results=preloaded_results)) + assert self.control_connection.wait_for_schema_agreement(preloaded_results=preloaded_results) # the control connection should not have slept at all - self.assertEqual(self.time.clock, 0) - self.assertEqual(self.connection.wait_for_responses.call_count, 1) + assert self.time.clock == 0 + assert self.connection.wait_for_responses.call_count == 1 def test_wait_for_schema_agreement_fails(self): """ @@ -241,9 +241,9 @@ def test_wait_for_schema_agreement_fails(self): """ # change the schema version on one node self.connection.peer_results[1][1][2] = 'b' - self.assertFalse(self.control_connection.wait_for_schema_agreement()) + assert not self.control_connection.wait_for_schema_agreement() # the control connection should have slept until it hit the limit - self.assertGreaterEqual(self.time.clock, self.cluster.max_schema_agreement_wait) + assert self.time.clock >= self.cluster.max_schema_agreement_wait def test_wait_for_schema_agreement_skipping(self): """ @@ -262,8 +262,8 @@ def test_wait_for_schema_agreement_skipping(self): self.connection.peer_results[1][1][3] = 'c' self.cluster.metadata.get_host(DefaultEndPoint('192.168.1.1')).is_up = False - self.assertTrue(self.control_connection.wait_for_schema_agreement()) - self.assertEqual(self.time.clock, 0) + assert self.control_connection.wait_for_schema_agreement() + assert self.time.clock == 0 def test_wait_for_schema_agreement_rpc_lookup(self): """ @@ -279,38 +279,38 @@ def test_wait_for_schema_agreement_rpc_lookup(self): # even though the new host has a different schema version, it's # marked as down, so the control connection shouldn't care - self.assertTrue(self.control_connection.wait_for_schema_agreement()) - self.assertEqual(self.time.clock, 0) + assert self.control_connection.wait_for_schema_agreement() + assert self.time.clock == 0 # but once we mark it up, the control 
connection will care host.is_up = True - self.assertFalse(self.control_connection.wait_for_schema_agreement()) - self.assertGreaterEqual(self.time.clock, self.cluster.max_schema_agreement_wait) + assert not self.control_connection.wait_for_schema_agreement() + assert self.time.clock >= self.cluster.max_schema_agreement_wait def test_refresh_nodes_and_tokens(self): self.control_connection.refresh_node_list_and_token_map() meta = self.cluster.metadata - self.assertEqual(meta.partitioner, 'Murmur3Partitioner') - self.assertEqual(meta.cluster_name, 'foocluster') + assert meta.partitioner == 'Murmur3Partitioner' + assert meta.cluster_name == 'foocluster' # check token map - self.assertEqual(sorted(meta.all_hosts()), sorted(meta.token_map.keys())) + assert sorted(meta.all_hosts()) == sorted(meta.token_map.keys()) for token_list in meta.token_map.values(): - self.assertEqual(3, len(token_list)) + assert 3 == len(token_list) # check datacenter/rack for host in meta.all_hosts(): - self.assertEqual(host.datacenter, "dc1") - self.assertEqual(host.rack, "rack1") + assert host.datacenter == "dc1" + assert host.rack == "rack1" - self.assertEqual(self.connection.wait_for_responses.call_count, 1) + assert self.connection.wait_for_responses.call_count == 1 def test_refresh_nodes_and_tokens_with_invalid_peers(self): def refresh_and_validate_added_hosts(): self.connection.wait_for_responses = Mock(return_value=_node_meta_results( self.connection.local_results, self.connection.peer_results)) self.control_connection.refresh_node_list_and_token_map() - self.assertEqual(1, len(self.cluster.added_hosts)) # only one valid peer found + assert 1 == len(self.cluster.added_hosts) # only one valid peer found # peersV1 del self.connection.peer_results[:] @@ -357,12 +357,12 @@ def test_change_ip(self): self.connection.local_results, self.connection.peer_results)) self.control_connection.refresh_node_list_and_token_map() # all peers are updated - self.assertEqual(0, len(self.cluster.added_hosts)) 
+ assert 0 == len(self.cluster.added_hosts) assert self.cluster.metadata.get_host('192.168.1.5') assert self.cluster.metadata.get_host('192.168.1.6') - self.assertEqual(3, len(self.cluster.metadata.all_hosts())) + assert 3 == len(self.cluster.metadata.all_hosts()) def test_refresh_nodes_and_tokens_uses_preloaded_results_if_given(self): @@ -372,21 +372,21 @@ def test_refresh_nodes_and_tokens_uses_preloaded_results_if_given(self): preloaded_results = self._matching_schema_preloaded_results self.control_connection._refresh_node_list_and_token_map(self.connection, preloaded_results=preloaded_results) meta = self.cluster.metadata - self.assertEqual(meta.partitioner, 'Murmur3Partitioner') - self.assertEqual(meta.cluster_name, 'foocluster') + assert meta.partitioner == 'Murmur3Partitioner' + assert meta.cluster_name == 'foocluster' # check token map - self.assertEqual(sorted(meta.all_hosts()), sorted(meta.token_map.keys())) + assert sorted(meta.all_hosts()) == sorted(meta.token_map.keys()) for token_list in meta.token_map.values(): - self.assertEqual(3, len(token_list)) + assert 3 == len(token_list) # check datacenter/rack for host in meta.all_hosts(): - self.assertEqual(host.datacenter, "dc1") - self.assertEqual(host.rack, "rack1") + assert host.datacenter == "dc1" + assert host.rack == "rack1" # the connection should not have made any queries if given preloaded results - self.assertEqual(self.connection.wait_for_responses.call_count, 0) + assert self.connection.wait_for_responses.call_count == 0 def test_refresh_nodes_and_tokens_no_partitioner(self): """ @@ -396,8 +396,8 @@ def test_refresh_nodes_and_tokens_no_partitioner(self): self.connection.local_results[1][0][5] = None self.control_connection.refresh_node_list_and_token_map() meta = self.cluster.metadata - self.assertEqual(meta.partitioner, None) - self.assertEqual(meta.token_map, {}) + assert meta.partitioner == None + assert meta.token_map == {} def test_refresh_nodes_and_tokens_add_host(self): 
self.connection.peer_results[1].append( @@ -405,22 +405,22 @@ def test_refresh_nodes_and_tokens_add_host(self): ) self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs) self.control_connection.refresh_node_list_and_token_map() - self.assertEqual(1, len(self.cluster.added_hosts)) - self.assertEqual(self.cluster.added_hosts[0].address, "192.168.1.3") - self.assertEqual(self.cluster.added_hosts[0].datacenter, "dc1") - self.assertEqual(self.cluster.added_hosts[0].rack, "rack1") - self.assertEqual(self.cluster.added_hosts[0].host_id, "uuid4") + assert 1 == len(self.cluster.added_hosts) + assert self.cluster.added_hosts[0].address == "192.168.1.3" + assert self.cluster.added_hosts[0].datacenter == "dc1" + assert self.cluster.added_hosts[0].rack == "rack1" + assert self.cluster.added_hosts[0].host_id == "uuid4" def test_refresh_nodes_and_tokens_remove_host(self): del self.connection.peer_results[1][1] self.control_connection.refresh_node_list_and_token_map() - self.assertEqual(1, len(self.cluster.metadata.removed_hosts)) - self.assertEqual(self.cluster.metadata.removed_hosts[0].address, "192.168.1.2") + assert 1 == len(self.cluster.metadata.removed_hosts) + assert self.cluster.metadata.removed_hosts[0].address == "192.168.1.2" def test_refresh_nodes_and_tokens_timeout(self): def bad_wait_for_responses(*args, **kwargs): - self.assertEqual(kwargs['timeout'], self.control_connection._timeout) + assert kwargs['timeout'] == self.control_connection._timeout raise OperationTimedOut() self.connection.wait_for_responses = bad_wait_for_responses @@ -435,8 +435,8 @@ def bad_wait_for_responses(*args, **kwargs): self.connection.wait_for_responses = Mock(side_effect=bad_wait_for_responses) self.control_connection.refresh_schema() - self.assertEqual(self.connection.wait_for_responses.call_count, self.cluster.max_schema_agreement_wait / self.control_connection._timeout) - self.assertEqual(self.connection.wait_for_responses.call_args[1]['timeout'], 
self.control_connection._timeout) + assert self.connection.wait_for_responses.call_count == self.cluster.max_schema_agreement_wait / self.control_connection._timeout + assert self.connection.wait_for_responses.call_args[1]['timeout'] == self.control_connection._timeout def test_handle_topology_change(self): event = { @@ -489,7 +489,7 @@ def test_handle_status_change(self): 'address': ('1.2.3.4', 9000) } self.control_connection._handle_status_change(event) - self.assertFalse(self.cluster.scheduler.schedule.called) + assert not self.cluster.scheduler.schedule.called # do the same with a known Host event = { @@ -498,7 +498,7 @@ def test_handle_status_change(self): } self.control_connection._handle_status_change(event) host = self.cluster.metadata.get_host(DefaultEndPoint('192.168.1.0')) - self.assertIs(host, self.cluster.down_host) + assert host is self.cluster.down_host def test_handle_schema_change(self): @@ -545,8 +545,8 @@ def test_refresh_disabled(self): # no call on schema refresh cc_no_schema_refresh._handle_schema_change(schema_event) - self.assertFalse(cluster.scheduler.schedule.called) - self.assertFalse(cluster.scheduler.schedule_unique.called) + assert not cluster.scheduler.schedule.called + assert not cluster.scheduler.schedule_unique.called # topo and status changes as normal cc_no_schema_refresh._handle_status_change(status_event) @@ -559,8 +559,8 @@ def test_refresh_disabled(self): # no call on topo refresh cc_no_topo_refresh._handle_topology_change(topo_event) - self.assertFalse(cluster.scheduler.schedule.called) - self.assertFalse(cluster.scheduler.schedule_unique.called) + assert not cluster.scheduler.schedule.called + assert not cluster.scheduler.schedule_unique.called # schema and status change refresh as normal cc_no_topo_refresh._handle_status_change(status_event) @@ -579,15 +579,15 @@ def test_refresh_nodes_and_tokens_add_host_detects_port(self): self.connection.local_results, self.connection.peer_results)) self.cluster.scheduler.schedule = 
lambda delay, f, *args, **kwargs: f(*args, **kwargs) self.control_connection.refresh_node_list_and_token_map() - self.assertEqual(1, len(self.cluster.added_hosts)) - self.assertEqual(self.cluster.added_hosts[0].endpoint.address, "192.168.1.3") - self.assertEqual(self.cluster.added_hosts[0].endpoint.port, 555) - self.assertEqual(self.cluster.added_hosts[0].broadcast_rpc_address, "192.168.1.3") - self.assertEqual(self.cluster.added_hosts[0].broadcast_rpc_port, 555) - self.assertEqual(self.cluster.added_hosts[0].broadcast_address, "10.0.0.3") - self.assertEqual(self.cluster.added_hosts[0].broadcast_port, 666) - self.assertEqual(self.cluster.added_hosts[0].datacenter, "dc1") - self.assertEqual(self.cluster.added_hosts[0].rack, "rack1") + assert 1 == len(self.cluster.added_hosts) + assert self.cluster.added_hosts[0].endpoint.address == "192.168.1.3" + assert self.cluster.added_hosts[0].endpoint.port == 555 + assert self.cluster.added_hosts[0].broadcast_rpc_address == "192.168.1.3" + assert self.cluster.added_hosts[0].broadcast_rpc_port == 555 + assert self.cluster.added_hosts[0].broadcast_address == "10.0.0.3" + assert self.cluster.added_hosts[0].broadcast_port == 666 + assert self.cluster.added_hosts[0].datacenter == "dc1" + assert self.cluster.added_hosts[0].rack == "rack1" def test_refresh_nodes_and_tokens_add_host_detects_invalid_port(self): del self.connection.peer_results[:] @@ -599,15 +599,15 @@ def test_refresh_nodes_and_tokens_add_host_detects_invalid_port(self): self.connection.local_results, self.connection.peer_results)) self.cluster.scheduler.schedule = lambda delay, f, *args, **kwargs: f(*args, **kwargs) self.control_connection.refresh_node_list_and_token_map() - self.assertEqual(1, len(self.cluster.added_hosts)) - self.assertEqual(self.cluster.added_hosts[0].endpoint.address, "192.168.1.3") - self.assertEqual(self.cluster.added_hosts[0].endpoint.port, 9042) # fallback default - self.assertEqual(self.cluster.added_hosts[0].broadcast_rpc_address, 
"192.168.1.3") - self.assertEqual(self.cluster.added_hosts[0].broadcast_rpc_port, None) - self.assertEqual(self.cluster.added_hosts[0].broadcast_address, "10.0.0.3") - self.assertEqual(self.cluster.added_hosts[0].broadcast_port, None) - self.assertEqual(self.cluster.added_hosts[0].datacenter, "dc1") - self.assertEqual(self.cluster.added_hosts[0].rack, "rack1") + assert 1 == len(self.cluster.added_hosts) + assert self.cluster.added_hosts[0].endpoint.address == "192.168.1.3" + assert self.cluster.added_hosts[0].endpoint.port == 9042 # fallback default + assert self.cluster.added_hosts[0].broadcast_rpc_address == "192.168.1.3" + assert self.cluster.added_hosts[0].broadcast_rpc_port == None + assert self.cluster.added_hosts[0].broadcast_address == "10.0.0.3" + assert self.cluster.added_hosts[0].broadcast_port == None + assert self.cluster.added_hosts[0].datacenter == "dc1" + assert self.cluster.added_hosts[0].rack == "rack1" class EventTimingTest(unittest.TestCase): @@ -644,5 +644,5 @@ def test_event_delay_timing(self): self.cluster.scheduler.mock_calls # Grabs the delay parameter from the scheduler invocation current_delay = self.cluster.scheduler.mock_calls[0][1][0] - self.assertLess(prior_delay, current_delay) + assert prior_delay < current_delay prior_delay = current_delay diff --git a/tests/unit/test_endpoints.py b/tests/unit/test_endpoints.py index b0841962ca..14fb8b5806 100644 --- a/tests/unit/test_endpoints.py +++ b/tests/unit/test_endpoints.py @@ -31,31 +31,19 @@ class SniEndPointTest(unittest.TestCase): def test_sni_endpoint_properties(self): endpoint = self.endpoint_factory.create_from_sni('test') - self.assertEqual(endpoint.address, 'proxy.datastax.com') - self.assertEqual(endpoint.port, 30002) - self.assertEqual(endpoint._server_name, 'test') - self.assertEqual(str(endpoint), 'proxy.datastax.com:30002:test') + assert endpoint.address == 'proxy.datastax.com' + assert endpoint.port == 30002 + assert endpoint._server_name == 'test' + assert str(endpoint) == 
'proxy.datastax.com:30002:test' def test_endpoint_equality(self): - self.assertNotEqual( - DefaultEndPoint('10.0.0.1'), - self.endpoint_factory.create_from_sni('10.0.0.1') - ) - - self.assertEqual( - self.endpoint_factory.create_from_sni('10.0.0.1'), - self.endpoint_factory.create_from_sni('10.0.0.1') - ) - - self.assertNotEqual( - self.endpoint_factory.create_from_sni('10.0.0.1'), - self.endpoint_factory.create_from_sni('10.0.0.0') - ) - - self.assertNotEqual( - self.endpoint_factory.create_from_sni('10.0.0.1'), - SniEndPointFactory("proxy.datastax.com", 9999).create_from_sni('10.0.0.1') - ) + assert DefaultEndPoint('10.0.0.1') != self.endpoint_factory.create_from_sni('10.0.0.1') + + assert self.endpoint_factory.create_from_sni('10.0.0.1') == self.endpoint_factory.create_from_sni('10.0.0.1') + + assert self.endpoint_factory.create_from_sni('10.0.0.1') != self.endpoint_factory.create_from_sni('10.0.0.0') + + assert self.endpoint_factory.create_from_sni('10.0.0.1') != SniEndPointFactory("proxy.datastax.com", 9999).create_from_sni('10.0.0.1') def test_endpoint_resolve(self): ips = ['127.0.0.1', '127.0.0.2', '127.0.0.3'] @@ -64,4 +52,4 @@ def test_endpoint_resolve(self): endpoint = self.endpoint_factory.create_from_sni('test') for i in range(10): (address, _) = endpoint.resolve() - self.assertEqual(address, next(it)) + assert address == next(it) diff --git a/tests/unit/test_exception.py b/tests/unit/test_exception.py index b39b22239c..6bddd96a4b 100644 --- a/tests/unit/test_exception.py +++ b/tests/unit/test_exception.py @@ -37,17 +37,17 @@ def test_timeout_consistency(self): Verify that Timeout exception object translates consistency from input value to correct output string """ consistency_str = self.extract_consistency(repr(Timeout("Timeout Message", consistency=None))) - self.assertEqual(consistency_str, 'Not Set') + assert consistency_str == 'Not Set' for c in ConsistencyLevel.value_to_name.keys(): consistency_str = self.extract_consistency(repr(Timeout("Timeout 
Message", consistency=c))) - self.assertEqual(consistency_str, ConsistencyLevel.value_to_name[c]) + assert consistency_str == ConsistencyLevel.value_to_name[c] def test_unavailable_consistency(self): """ Verify that Unavailable exception object translates consistency from input value to correct output string """ consistency_str = self.extract_consistency(repr(Unavailable("Unavailable Message", consistency=None))) - self.assertEqual(consistency_str, 'Not Set') + assert consistency_str == 'Not Set' for c in ConsistencyLevel.value_to_name.keys(): consistency_str = self.extract_consistency(repr(Unavailable("Timeout Message", consistency=c))) - self.assertEqual(consistency_str, ConsistencyLevel.value_to_name[c]) + assert consistency_str == ConsistencyLevel.value_to_name[c] diff --git a/tests/unit/test_host_connection_pool.py b/tests/unit/test_host_connection_pool.py index 252ccb49ca..3fac0b18ef 100644 --- a/tests/unit/test_host_connection_pool.py +++ b/tests/unit/test_host_connection_pool.py @@ -27,6 +27,7 @@ from cassandra.pool import HostConnection from cassandra.pool import Host, NoConnectionsAvailable from cassandra.policies import HostDistance, SimpleConvictionPolicy +import pytest LOGGER = logging.getLogger(__name__) @@ -52,14 +53,14 @@ def test_borrow_and_return(self): session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released) c, request_id = pool.borrow_connection(timeout=0.01) - self.assertIs(c, conn) - self.assertEqual(1, conn.in_flight) + assert c is conn + assert 1 == conn.in_flight conn.set_keyspace_blocking.assert_called_once_with('foobarkeyspace') pool.return_connection(conn) - self.assertEqual(0, conn.in_flight) + assert 0 == conn.in_flight if not self.uses_single_connection: - self.assertNotIn(conn, pool._trash) + assert conn not in pool._trash def test_failed_wait_for_connection(self): host = Mock(spec=Host, address='ip1') @@ -71,13 +72,14 @@ def 
test_failed_wait_for_connection(self): session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released) pool.borrow_connection(timeout=0.01) - self.assertEqual(1, conn.in_flight) + assert 1 == conn.in_flight conn.in_flight = conn.max_request_id # we're already at the max number of requests for this connection, # so we this should fail - self.assertRaises(NoConnectionsAvailable, pool.borrow_connection, 0) + with pytest.raises(NoConnectionsAvailable): + pool.borrow_connection(0) def test_successful_wait_for_connection(self): host = Mock(spec=Host, address='ip1') @@ -90,11 +92,11 @@ def test_successful_wait_for_connection(self): session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released) pool.borrow_connection(timeout=0.01) - self.assertEqual(1, conn.in_flight) + assert 1 == conn.in_flight def get_second_conn(): c, request_id = pool.borrow_connection(1.0) - self.assertIs(conn, c) + assert conn is c pool.return_connection(c) t = Thread(target=get_second_conn) @@ -102,7 +104,7 @@ def get_second_conn(): pool.return_connection(conn) t.join() - self.assertEqual(0, conn.in_flight) + assert 0 == conn.in_flight def test_spawn_when_at_max(self): host = Mock(spec=Host, address='ip1') @@ -118,7 +120,7 @@ def test_spawn_when_at_max(self): session.cluster.connection_factory.assert_called_once_with(host.endpoint, on_orphaned_stream_released=pool.on_orphaned_stream_released) pool.borrow_connection(timeout=0.01) - self.assertEqual(1, conn.in_flight) + assert 1 == conn.in_flight # make this conn full conn.in_flight = conn.max_request_id @@ -126,7 +128,8 @@ def test_spawn_when_at_max(self): # we don't care about making this borrow_connection call succeed for the # purposes of this test, as long as it results in a new connection # creation being scheduled - self.assertRaises(NoConnectionsAvailable, pool.borrow_connection, 0) + with 
pytest.raises(NoConnectionsAvailable): + pool.borrow_connection(0) if not self.uses_single_connection: session.submit.assert_called_once_with(pool._create_new_connection) @@ -147,8 +150,8 @@ def test_return_defunct_connection(self): pool.return_connection(conn) # the connection should be closed a new creation scheduled - self.assertTrue(session.submit.call_args) - self.assertFalse(pool.is_shutdown) + assert session.submit.call_args + assert not pool.is_shutdown def test_return_defunct_connection_on_down_host(self): host = Mock(spec=Host, address='ip1') @@ -169,15 +172,15 @@ def test_return_defunct_connection_on_down_host(self): pool.return_connection(conn) # the connection should be closed a new creation scheduled - self.assertTrue(conn.close.call_args) + assert conn.close.call_args if self.PoolImpl is HostConnection: # on shard aware implementation we use submit function regardless - self.assertTrue(host.signal_connection_failure.call_args) - self.assertTrue(session.submit.called) + assert host.signal_connection_failure.call_args + assert session.submit.called else: - self.assertFalse(session.submit.called) - self.assertTrue(session.cluster.signal_connection_failure.call_args) - self.assertTrue(pool.is_shutdown) + assert not session.submit.called + assert session.cluster.signal_connection_failure.call_args + assert pool.is_shutdown def test_return_closed_connection(self): host = Mock(spec=Host, address='ip1') @@ -196,17 +199,20 @@ def test_return_closed_connection(self): pool.return_connection(conn) # a new creation should be scheduled - self.assertTrue(session.submit.call_args) - self.assertFalse(pool.is_shutdown) + assert session.submit.call_args + assert not pool.is_shutdown def test_host_instantiations(self): """ Ensure Host fails if not initialized properly """ - self.assertRaises(ValueError, Host, None, None) - self.assertRaises(ValueError, Host, '127.0.0.1', None) - self.assertRaises(ValueError, Host, None, SimpleConvictionPolicy) + with 
pytest.raises(ValueError): + Host(None, None) + with pytest.raises(ValueError): + Host('127.0.0.1', None) + with pytest.raises(ValueError): + Host(None, SimpleConvictionPolicy) def test_host_equality(self): """ @@ -217,9 +223,9 @@ def test_host_equality(self): b = Host('127.0.0.1', SimpleConvictionPolicy) c = Host('127.0.0.2', SimpleConvictionPolicy) - self.assertEqual(a, b, 'Two Host instances should be equal when sharing.') - self.assertNotEqual(a, c, 'Two Host instances should NOT be equal when using two different addresses.') - self.assertNotEqual(b, c, 'Two Host instances should NOT be equal when using two different addresses.') + assert a == b, 'Two Host instances should be equal when sharing.' + assert a != c, 'Two Host instances should NOT be equal when using two different addresses.' + assert b != c, 'Two Host instances should NOT be equal when using two different addresses.' class HostConnectionTests(_PoolTests): diff --git a/tests/unit/test_marshalling.py b/tests/unit/test_marshalling.py index 9c368860f3..e4b415ac69 100644 --- a/tests/unit/test_marshalling.py +++ b/tests/unit/test_marshalling.py @@ -112,27 +112,19 @@ def test_unmarshalling(self): for serializedval, valtype, nativeval in marshalled_value_pairs: unmarshaller = lookup_casstype(valtype) whatwegot = unmarshaller.from_binary(serializedval, 3) - self.assertEqual(whatwegot, nativeval, - msg='Unmarshaller for %s (%s) failed: unmarshal(%r) got %r instead of %r' - % (valtype, unmarshaller, serializedval, whatwegot, nativeval)) - self.assertEqual(type(whatwegot), type(nativeval), - msg='Unmarshaller for %s (%s) gave wrong type (%s instead of %s)' - % (valtype, unmarshaller, type(whatwegot), type(nativeval))) + assert whatwegot == nativeval, 'Unmarshaller for %s (%s) failed: unmarshal(%r) got %r instead of %r' % (valtype, unmarshaller, serializedval, whatwegot, nativeval) + assert type(whatwegot) == type(nativeval), 'Unmarshaller for %s (%s) gave wrong type (%s instead of %s)' % (valtype, 
unmarshaller, type(whatwegot), type(nativeval)) def test_marshalling(self): for serializedval, valtype, nativeval in marshalled_value_pairs: marshaller = lookup_casstype(valtype) whatwegot = marshaller.to_binary(nativeval, 3) - self.assertEqual(whatwegot, serializedval, - msg='Marshaller for %s (%s) failed: marshal(%r) got %r instead of %r' - % (valtype, marshaller, nativeval, whatwegot, serializedval)) - self.assertEqual(type(whatwegot), type(serializedval), - msg='Marshaller for %s (%s) gave wrong type (%s instead of %s)' - % (valtype, marshaller, type(whatwegot), type(serializedval))) + assert whatwegot == serializedval, 'Marshaller for %s (%s) failed: marshal(%r) got %r instead of %r' % (valtype, marshaller, nativeval, whatwegot, serializedval) + assert type(whatwegot) == type(serializedval), 'Marshaller for %s (%s) gave wrong type (%s instead of %s)' % (valtype, marshaller, type(whatwegot), type(serializedval)) def test_date(self): # separate test because it will deserialize as datetime - self.assertEqual(DateType.from_binary(DateType.to_binary(date(2015, 11, 2), 3), 3), datetime(2015, 11, 2)) + assert DateType.from_binary(DateType.to_binary(date(2015, 11, 2), 3), 3) == datetime(2015, 11, 2) def test_decimal(self): # testing implicit numeric conversion @@ -142,4 +134,4 @@ def test_decimal(self): for proto_ver in range(3, ProtocolVersion.MAX_SUPPORTED + 1): for n in converted_types: expected = Decimal(n) - self.assertEqual(DecimalType.from_binary(DecimalType.to_binary(n, proto_ver), proto_ver), expected) + assert DecimalType.from_binary(DecimalType.to_binary(n, proto_ver), proto_ver) == expected diff --git a/tests/unit/test_metadata.py b/tests/unit/test_metadata.py index bc5a93bf89..3069f6bced 100644 --- a/tests/unit/test_metadata.py +++ b/tests/unit/test_metadata.py @@ -33,6 +33,8 @@ Metadata, TokenMap, ReplicationFactor) from cassandra.policies import SimpleConvictionPolicy from cassandra.pool import Host +from tests.util import assertCountEqual +import 
pytest log = logging.getLogger(__name__) @@ -42,38 +44,36 @@ class ReplicationFactorTest(unittest.TestCase): def test_replication_factor_parsing(self): rf = ReplicationFactor.create('3') - self.assertEqual(rf.all_replicas, 3) - self.assertEqual(rf.full_replicas, 3) - self.assertEqual(rf.transient_replicas, None) - self.assertEqual(str(rf), '3') + assert rf.all_replicas == 3 + assert rf.full_replicas == 3 + assert rf.transient_replicas == None + assert str(rf) == '3' rf = ReplicationFactor.create('3/1') - self.assertEqual(rf.all_replicas, 3) - self.assertEqual(rf.full_replicas, 2) - self.assertEqual(rf.transient_replicas, 1) - self.assertEqual(str(rf), '3/1') - - self.assertRaises(ValueError, ReplicationFactor.create, '3/') - self.assertRaises(ValueError, ReplicationFactor.create, 'a/1') - self.assertRaises(ValueError, ReplicationFactor.create, 'a') - self.assertRaises(ValueError, ReplicationFactor.create, '3/a') + assert rf.all_replicas == 3 + assert rf.full_replicas == 2 + assert rf.transient_replicas == 1 + assert str(rf) == '3/1' + + with pytest.raises(ValueError): + ReplicationFactor.create('3/') + with pytest.raises(ValueError): + ReplicationFactor.create('a/1') + with pytest.raises(ValueError): + ReplicationFactor.create('a') + with pytest.raises(ValueError): + ReplicationFactor.create('3/a') def test_replication_factor_equality(self): - self.assertEqual(ReplicationFactor.create('3/1'), ReplicationFactor.create('3/1')) - self.assertEqual(ReplicationFactor.create('3'), ReplicationFactor.create('3')) - self.assertNotEqual(ReplicationFactor.create('3'), ReplicationFactor.create('3/1')) - self.assertNotEqual(ReplicationFactor.create('3'), ReplicationFactor.create('3/1')) + assert ReplicationFactor.create('3/1') == ReplicationFactor.create('3/1') + assert ReplicationFactor.create('3') == ReplicationFactor.create('3') + assert ReplicationFactor.create('3') != ReplicationFactor.create('3/1') + assert ReplicationFactor.create('3') != ReplicationFactor.create('3/1') 
class StrategiesTest(unittest.TestCase): - @classmethod - def setUpClass(cls): - "Hook method for setting up class fixture before running tests in the class." - if not hasattr(cls, 'assertItemsEqual'): - cls.assertItemsEqual = cls.assertCountEqual - def test_replication_strategy(self): """ Basic code coverage testing that ensures different ReplicationStrategies @@ -82,32 +82,32 @@ def test_replication_strategy(self): rs = ReplicationStrategy() - self.assertEqual(rs.create('OldNetworkTopologyStrategy', None), _UnknownStrategy('OldNetworkTopologyStrategy', None)) + assert rs.create('OldNetworkTopologyStrategy', None) == _UnknownStrategy('OldNetworkTopologyStrategy', None) fake_options_map = {'options': 'map'} uks = rs.create('OldNetworkTopologyStrategy', fake_options_map) - self.assertEqual(uks, _UnknownStrategy('OldNetworkTopologyStrategy', fake_options_map)) - self.assertEqual(uks.make_token_replica_map({}, []), {}) + assert uks == _UnknownStrategy('OldNetworkTopologyStrategy', fake_options_map) + assert uks.make_token_replica_map({}, []) == {} fake_options_map = {'dc1': '3'} - self.assertIsInstance(rs.create('NetworkTopologyStrategy', fake_options_map), NetworkTopologyStrategy) - self.assertEqual(rs.create('NetworkTopologyStrategy', fake_options_map).dc_replication_factors, - NetworkTopologyStrategy(fake_options_map).dc_replication_factors) + assert isinstance(rs.create('NetworkTopologyStrategy', fake_options_map), NetworkTopologyStrategy) + assert rs.create('NetworkTopologyStrategy', fake_options_map).dc_replication_factors == NetworkTopologyStrategy(fake_options_map).dc_replication_factors fake_options_map = {'options': 'map'} - self.assertIsNone(rs.create('SimpleStrategy', fake_options_map)) + assert rs.create('SimpleStrategy', fake_options_map) is None fake_options_map = {'options': 'map'} - self.assertIsInstance(rs.create('LocalStrategy', fake_options_map), LocalStrategy) + assert isinstance(rs.create('LocalStrategy', fake_options_map), LocalStrategy) 
fake_options_map = {'options': 'map', 'replication_factor': 3} - self.assertIsInstance(rs.create('SimpleStrategy', fake_options_map), SimpleStrategy) - self.assertEqual(rs.create('SimpleStrategy', fake_options_map).replication_factor, - SimpleStrategy(fake_options_map).replication_factor) + assert isinstance(rs.create('SimpleStrategy', fake_options_map), SimpleStrategy) + assert rs.create('SimpleStrategy', fake_options_map).replication_factor == SimpleStrategy(fake_options_map).replication_factor - self.assertEqual(rs.create('xxxxxxxx', fake_options_map), _UnknownStrategy('xxxxxxxx', fake_options_map)) + assert rs.create('xxxxxxxx', fake_options_map) == _UnknownStrategy('xxxxxxxx', fake_options_map) - self.assertRaises(NotImplementedError, rs.make_token_replica_map, None, None) - self.assertRaises(NotImplementedError, rs.export_for_schema) + with pytest.raises(NotImplementedError): + rs.make_token_replica_map(None, None) + with pytest.raises(NotImplementedError): + rs.export_for_schema() def test_simple_replication_type_parsing(self): """ Test equality between passing numeric and string replication factor for simple strategy """ @@ -116,38 +116,32 @@ def test_simple_replication_type_parsing(self): simple_int = rs.create('SimpleStrategy', {'replication_factor': 3}) simple_str = rs.create('SimpleStrategy', {'replication_factor': '3'}) - self.assertEqual(simple_int.export_for_schema(), simple_str.export_for_schema()) - self.assertEqual(simple_int, simple_str) + assert simple_int.export_for_schema() == simple_str.export_for_schema() + assert simple_int == simple_str # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] token_to_host = dict(zip(ring, hosts)) - self.assertEqual( - simple_int.make_token_replica_map(token_to_host, ring), - simple_str.make_token_replica_map(token_to_host, ring) - ) + assert simple_int.make_token_replica_map(token_to_host, ring) == 
simple_str.make_token_replica_map(token_to_host, ring) def test_transient_replication_parsing(self): """ Test that we can PARSE a transient replication factor for SimpleStrategy """ rs = ReplicationStrategy() simple_transient = rs.create('SimpleStrategy', {'replication_factor': '3/1'}) - self.assertEqual(simple_transient.replication_factor_info, ReplicationFactor(3, 1)) - self.assertEqual(simple_transient.replication_factor, 2) - self.assertIn("'replication_factor': '3/1'", simple_transient.export_for_schema()) + assert simple_transient.replication_factor_info == ReplicationFactor(3, 1) + assert simple_transient.replication_factor == 2 + assert "'replication_factor': '3/1'" in simple_transient.export_for_schema() simple_str = rs.create('SimpleStrategy', {'replication_factor': '2'}) - self.assertNotEqual(simple_transient, simple_str) + assert simple_transient != simple_str # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] token_to_host = dict(zip(ring, hosts)) - self.assertEqual( - simple_transient.make_token_replica_map(token_to_host, ring), - simple_str.make_token_replica_map(token_to_host, ring) - ) + assert simple_transient.make_token_replica_map(token_to_host, ring) == simple_str.make_token_replica_map(token_to_host, ring) def test_nts_replication_parsing(self): """ Test equality between passing numeric and string replication factor for NTS """ @@ -156,45 +150,39 @@ def test_nts_replication_parsing(self): nts_int = rs.create('NetworkTopologyStrategy', {'dc1': 3, 'dc2': 5}) nts_str = rs.create('NetworkTopologyStrategy', {'dc1': '3', 'dc2': '5'}) - self.assertEqual(nts_int.dc_replication_factors['dc1'], 3) - self.assertEqual(nts_str.dc_replication_factors['dc1'], 3) - self.assertEqual(nts_int.dc_replication_factors_info['dc1'], ReplicationFactor(3)) - self.assertEqual(nts_str.dc_replication_factors_info['dc1'], ReplicationFactor(3)) + assert 
nts_int.dc_replication_factors['dc1'] == 3 + assert nts_str.dc_replication_factors['dc1'] == 3 + assert nts_int.dc_replication_factors_info['dc1'] == ReplicationFactor(3) + assert nts_str.dc_replication_factors_info['dc1'] == ReplicationFactor(3) - self.assertEqual(nts_int.export_for_schema(), nts_str.export_for_schema()) - self.assertEqual(nts_int, nts_str) + assert nts_int.export_for_schema() == nts_str.export_for_schema() + assert nts_int == nts_str # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] token_to_host = dict(zip(ring, hosts)) - self.assertEqual( - nts_int.make_token_replica_map(token_to_host, ring), - nts_str.make_token_replica_map(token_to_host, ring) - ) + assert nts_int.make_token_replica_map(token_to_host, ring) == nts_str.make_token_replica_map(token_to_host, ring) def test_nts_transient_parsing(self): """ Test that we can PARSE a transient replication factor for NTS """ rs = ReplicationStrategy() nts_transient = rs.create('NetworkTopologyStrategy', {'dc1': '3/1', 'dc2': '5/1'}) - self.assertEqual(nts_transient.dc_replication_factors_info['dc1'], ReplicationFactor(3, 1)) - self.assertEqual(nts_transient.dc_replication_factors_info['dc2'], ReplicationFactor(5, 1)) - self.assertEqual(nts_transient.dc_replication_factors['dc1'], 2) - self.assertEqual(nts_transient.dc_replication_factors['dc2'], 4) - self.assertIn("'dc1': '3/1', 'dc2': '5/1'", nts_transient.export_for_schema()) + assert nts_transient.dc_replication_factors_info['dc1'] == ReplicationFactor(3, 1) + assert nts_transient.dc_replication_factors_info['dc2'] == ReplicationFactor(5, 1) + assert nts_transient.dc_replication_factors['dc1'] == 2 + assert nts_transient.dc_replication_factors['dc2'] == 4 + assert "'dc1': '3/1', 'dc2': '5/1'" in nts_transient.export_for_schema() nts_str = rs.create('NetworkTopologyStrategy', {'dc1': '3', 'dc2': '5'}) - self.assertNotEqual(nts_transient, 
nts_str) + assert nts_transient != nts_str # make token replica map ring = [MD5Token(0), MD5Token(1), MD5Token(2)] hosts = [Host('dc1.{}'.format(host), SimpleConvictionPolicy) for host in range(3)] token_to_host = dict(zip(ring, hosts)) - self.assertEqual( - nts_transient.make_token_replica_map(token_to_host, ring), - nts_str.make_token_replica_map(token_to_host, ring) - ) + assert nts_transient.make_token_replica_map(token_to_host, ring) == nts_str.make_token_replica_map(token_to_host, ring) def test_nts_make_token_replica_map(self): token_to_host_owner = {} @@ -229,7 +217,7 @@ def test_nts_make_token_replica_map(self): nts = NetworkTopologyStrategy({'dc1': 2, 'dc2': 2, 'dc3': 1}) replica_map = nts.make_token_replica_map(token_to_host_owner, ring) - self.assertItemsEqual(replica_map[MD5Token(0)], (dc1_1, dc1_2, dc2_1, dc2_2, dc3_1)) + assertCountEqual(replica_map[MD5Token(0)], (dc1_1, dc1_2, dc2_1, dc2_2, dc3_1)) def test_nts_token_performance(self): """ @@ -268,7 +256,7 @@ def test_nts_token_performance(self): nts.make_token_replica_map(token_to_host_owner, ring) elapsed_bad = timeit.default_timer() - start_time difference = elapsed_bad - elapsed_base - self.assertTrue(difference < 1 and difference > -1) + assert difference < 1 and difference > -1 def test_nts_make_token_replica_map_multi_rack(self): token_to_host_owner = {} @@ -310,7 +298,7 @@ def test_nts_make_token_replica_map_multi_rack(self): replica_map = nts.make_token_replica_map(token_to_host_owner, ring) token_replicas = replica_map[MD5Token(0)] - self.assertItemsEqual(token_replicas, (dc1_1, dc1_2, dc1_3, dc2_1, dc2_3)) + assertCountEqual(token_replicas, (dc1_1, dc1_2, dc1_3, dc2_1, dc2_3)) def test_nts_make_token_replica_map_empty_dc(self): host = Host('1', SimpleConvictionPolicy) @@ -320,12 +308,11 @@ def test_nts_make_token_replica_map_empty_dc(self): nts = NetworkTopologyStrategy({'dc1': 1, 'dc2': 0}) replica_map = nts.make_token_replica_map(token_to_host_owner, ring) - 
self.assertEqual(set(replica_map[MD5Token(0)]), set([host])) + assert set(replica_map[MD5Token(0)]) == set([host]) def test_nts_export_for_schema(self): strategy = NetworkTopologyStrategy({'dc1': '1', 'dc2': '2'}) - self.assertEqual("{'class': 'NetworkTopologyStrategy', 'dc1': '1', 'dc2': '2'}", - strategy.export_for_schema()) + assert "{'class': 'NetworkTopologyStrategy', 'dc1': '1', 'dc2': '2'}" == strategy.export_for_schema() def test_simple_strategy_make_token_replica_map(self): host1 = Host('1', SimpleConvictionPolicy) @@ -339,22 +326,22 @@ def test_simple_strategy_make_token_replica_map(self): ring = [MD5Token(0), MD5Token(100), MD5Token(200)] rf1_replicas = SimpleStrategy({'replication_factor': '1'}).make_token_replica_map(token_to_host_owner, ring) - self.assertItemsEqual(rf1_replicas[MD5Token(0)], [host1]) - self.assertItemsEqual(rf1_replicas[MD5Token(100)], [host2]) - self.assertItemsEqual(rf1_replicas[MD5Token(200)], [host3]) + assertCountEqual(rf1_replicas[MD5Token(0)], [host1]) + assertCountEqual(rf1_replicas[MD5Token(100)], [host2]) + assertCountEqual(rf1_replicas[MD5Token(200)], [host3]) rf2_replicas = SimpleStrategy({'replication_factor': '2'}).make_token_replica_map(token_to_host_owner, ring) - self.assertItemsEqual(rf2_replicas[MD5Token(0)], [host1, host2]) - self.assertItemsEqual(rf2_replicas[MD5Token(100)], [host2, host3]) - self.assertItemsEqual(rf2_replicas[MD5Token(200)], [host3, host1]) + assertCountEqual(rf2_replicas[MD5Token(0)], [host1, host2]) + assertCountEqual(rf2_replicas[MD5Token(100)], [host2, host3]) + assertCountEqual(rf2_replicas[MD5Token(200)], [host3, host1]) rf3_replicas = SimpleStrategy({'replication_factor': '3'}).make_token_replica_map(token_to_host_owner, ring) - self.assertItemsEqual(rf3_replicas[MD5Token(0)], [host1, host2, host3]) - self.assertItemsEqual(rf3_replicas[MD5Token(100)], [host2, host3, host1]) - self.assertItemsEqual(rf3_replicas[MD5Token(200)], [host3, host1, host2]) + 
assertCountEqual(rf3_replicas[MD5Token(0)], [host1, host2, host3]) + assertCountEqual(rf3_replicas[MD5Token(100)], [host2, host3, host1]) + assertCountEqual(rf3_replicas[MD5Token(200)], [host3, host1, host2]) def test_ss_equals(self): - self.assertNotEqual(SimpleStrategy({'replication_factor': '1'}), NetworkTopologyStrategy({'dc1': 2})) + assert SimpleStrategy({'replication_factor': '1'}) != NetworkTopologyStrategy({'dc1': 2}) class NameEscapingTest(unittest.TestCase): @@ -363,58 +350,57 @@ def test_protect_name(self): """ Test cassandra.metadata.protect_name output """ - self.assertEqual(protect_name('tests'), 'tests') - self.assertEqual(protect_name('test\'s'), '"test\'s"') - self.assertEqual(protect_name('test\'s'), "\"test's\"") - self.assertEqual(protect_name('tests ?!@#$%^&*()'), '"tests ?!@#$%^&*()"') - self.assertEqual(protect_name('1'), '"1"') - self.assertEqual(protect_name('1test'), '"1test"') + assert protect_name('tests') == 'tests' + assert protect_name('test\'s') == '"test\'s"' + assert protect_name('test\'s') == "\"test's\"" + assert protect_name('tests ?!@#$%^&*()') == '"tests ?!@#$%^&*()"' + assert protect_name('1') == '"1"' + assert protect_name('1test') == '"1test"' def test_protect_names(self): """ Test cassandra.metadata.protect_names output """ - self.assertEqual(protect_names(['tests']), ['tests']) - self.assertEqual(protect_names( + assert protect_names(['tests']) == ['tests'] + assert protect_names( [ 'tests', 'test\'s', 'tests ?!@#$%^&*()', '1' - ]), - [ - 'tests', - "\"test's\"", - '"tests ?!@#$%^&*()"', - '"1"' - ]) + ]) == [ + 'tests', + "\"test's\"", + '"tests ?!@#$%^&*()"', + '"1"' + ] def test_protect_value(self): """ Test cassandra.metadata.protect_value output """ - self.assertEqual(protect_value(True), "true") - self.assertEqual(protect_value(False), "false") - self.assertEqual(protect_value(3.14), '3.14') - self.assertEqual(protect_value(3), '3') - self.assertEqual(protect_value('test'), "'test'") - 
self.assertEqual(protect_value('test\'s'), "'test''s'") - self.assertEqual(protect_value(None), 'NULL') + assert protect_value(True) == "true" + assert protect_value(False) == "false" + assert protect_value(3.14) == '3.14' + assert protect_value(3) == '3' + assert protect_value('test') == "'test'" + assert protect_value('test\'s') == "'test''s'" + assert protect_value(None) == 'NULL' def test_is_valid_name(self): """ Test cassandra.metadata.is_valid_name output """ - self.assertEqual(is_valid_name(None), False) - self.assertEqual(is_valid_name('test'), True) - self.assertEqual(is_valid_name('Test'), False) - self.assertEqual(is_valid_name('t_____1'), True) - self.assertEqual(is_valid_name('test1'), True) - self.assertEqual(is_valid_name('1test1'), False) + assert is_valid_name(None) == False + assert is_valid_name('test') == True + assert is_valid_name('Test') == False + assert is_valid_name('t_____1') == True + assert is_valid_name('test1') == True + assert is_valid_name('1test1') == False invalid_keywords = cassandra.metadata.cql_keywords - cassandra.metadata.cql_keywords_unreserved for keyword in invalid_keywords: - self.assertEqual(is_valid_name(keyword), False) + assert is_valid_name(keyword) == False class GetReplicasTest(unittest.TestCase): @@ -429,18 +415,18 @@ def _get_replicas(self, token_klass): # tokens match node tokens exactly for token, expected_host in zip(tokens, hosts): replicas = token_map.get_replicas("ks", token) - self.assertEqual(set(replicas), {expected_host}) + assert set(replicas) == {expected_host} # shift the tokens back by one for token, expected_host in zip(tokens, hosts): replicas = token_map.get_replicas("ks", token_klass(token.value - 1)) - self.assertEqual(set(replicas), {expected_host}) + assert set(replicas) == {expected_host} # shift the tokens forward by one for i, token in enumerate(tokens): replicas = token_map.get_replicas("ks", token_klass(token.value + 1)) expected_host = hosts[(i + 1) % len(hosts)] - 
self.assertEqual(set(replicas), {expected_host}) + assert set(replicas) == {expected_host} def test_murmur3_tokens(self): self._get_replicas(Murmur3Token) @@ -456,7 +442,7 @@ class Murmur3TokensTest(unittest.TestCase): def test_murmur3_init(self): murmur3_token = Murmur3Token(cassandra.metadata.MIN_LONG - 1) - self.assertEqual(str(murmur3_token), '') + assert str(murmur3_token) == '' def test_python_vs_c(self): from cassandra.murmur3 import _murmur3 as mm3_python @@ -467,7 +453,7 @@ def test_python_vs_c(self): for _ in range(iterations): for len in range(0, 32): # zero to one block plus full range of tail lengths key = os.urandom(len) - self.assertEqual(mm3_python(key), mm3_c(key)) + assert mm3_python(key) == mm3_c(key) except ImportError: raise unittest.SkipTest('The cmurmur3 extension is not available') @@ -484,51 +470,51 @@ def test_murmur3_c(self): raise unittest.SkipTest('The cmurmur3 extension is not available') def _verify_hash(self, fn): - self.assertEqual(fn(b'123'), -7468325962851647638) - self.assertEqual(fn(b'\x00\xff\x10\xfa\x99' * 10), 5837342703291459765) - self.assertEqual(fn(b'\xfe' * 8), -8927430733708461935) - self.assertEqual(fn(b'\x10' * 8), 1446172840243228796) - self.assertEqual(fn(str(cassandra.metadata.MAX_LONG).encode()), 7162290910810015547) + assert fn(b'123') == -7468325962851647638 + assert fn(b'\x00\xff\x10\xfa\x99' * 10) == 5837342703291459765 + assert fn(b'\xfe' * 8) == -8927430733708461935 + assert fn(b'\x10' * 8) == 1446172840243228796 + assert fn(str(cassandra.metadata.MAX_LONG).encode()) == 7162290910810015547 class MD5TokensTest(unittest.TestCase): def test_md5_tokens(self): md5_token = MD5Token(cassandra.metadata.MIN_LONG - 1) - self.assertEqual(md5_token.hash_fn('123'), 42767516990368493138776584305024125808) - self.assertEqual(md5_token.hash_fn(str(cassandra.metadata.MAX_LONG)), 28528976619278518853815276204542453639) - self.assertEqual(str(md5_token), '' % -9223372036854775809) + assert md5_token.hash_fn('123') == 
42767516990368493138776584305024125808 + assert md5_token.hash_fn(str(cassandra.metadata.MAX_LONG)) == 28528976619278518853815276204542453639 + assert str(md5_token) == '' % -9223372036854775809 class BytesTokensTest(unittest.TestCase): def test_bytes_tokens(self): bytes_token = BytesToken(unhexlify(b'01')) - self.assertEqual(bytes_token.value, b'\x01') - self.assertEqual(str(bytes_token), "" % bytes_token.value) - self.assertEqual(bytes_token.hash_fn('123'), '123') - self.assertEqual(bytes_token.hash_fn(123), 123) - self.assertEqual(bytes_token.hash_fn(str(cassandra.metadata.MAX_LONG)), str(cassandra.metadata.MAX_LONG)) + assert bytes_token.value == b'\x01' + assert str(bytes_token) == "" % bytes_token.value + assert bytes_token.hash_fn('123') == '123' + assert bytes_token.hash_fn(123) == 123 + assert bytes_token.hash_fn(str(cassandra.metadata.MAX_LONG)) == str(cassandra.metadata.MAX_LONG) def test_from_string(self): from_unicode = BytesToken.from_string('0123456789abcdef') from_bin = BytesToken.from_string(b'0123456789abcdef') - self.assertEqual(from_unicode, from_bin) - self.assertIsInstance(from_unicode.value, bytes) - self.assertIsInstance(from_bin.value, bytes) + assert from_unicode == from_bin + assert isinstance(from_unicode.value, bytes) + assert isinstance(from_bin.value, bytes) def test_comparison(self): tok = BytesToken.from_string('0123456789abcdef') token_high_order = uint16_unpack(tok.value[0:2]) - self.assertLess(BytesToken(uint16_pack(token_high_order - 1)), tok) - self.assertGreater(BytesToken(uint16_pack(token_high_order + 1)), tok) + assert BytesToken(uint16_pack(token_high_order - 1)) < tok + assert BytesToken(uint16_pack(token_high_order + 1)) > tok def test_comparison_unicode(self): value = b'\'_-()"\xc2\xac' t0 = BytesToken(value) t1 = BytesToken.from_string('00') - self.assertGreater(t0, t1) - self.assertFalse(t0 < t1) + assert t0 > t1 + assert not t0 < t1 class KeyspaceMetadataTest(unittest.TestCase): @@ -541,7 +527,7 @@ def 
test_export_as_string_user_types(self): keyspace.user_types['c'] = UserType(keyspace_name, 'c', ['one'], ['int']) keyspace.user_types['d'] = UserType(keyspace_name, 'd', ['one'], ['c']) - self.assertEqual("""CREATE KEYSPACE test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} AND durable_writes = true; + assert """CREATE KEYSPACE test WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '3'} AND durable_writes = true; CREATE TYPE test.c ( one int @@ -560,7 +546,7 @@ def test_export_as_string_user_types(self): one d, two int, three a -);""", keyspace.export_as_string()) +);""" == keyspace.export_as_string() class UserTypesTest(unittest.TestCase): @@ -568,17 +554,17 @@ class UserTypesTest(unittest.TestCase): def test_as_cql_query(self): field_types = ['varint', 'ascii', 'frozen>'] udt = UserType("ks1", "mytype", ["a", "b", "c"], field_types) - self.assertEqual("CREATE TYPE ks1.mytype (a varint, b ascii, c frozen>)", udt.as_cql_query(formatted=False)) + assert "CREATE TYPE ks1.mytype (a varint, b ascii, c frozen>)" == udt.as_cql_query(formatted=False) - self.assertEqual("""CREATE TYPE ks1.mytype ( + assert """CREATE TYPE ks1.mytype ( a varint, b ascii, c frozen> -);""", udt.export_as_string()) +);""" == udt.export_as_string() def test_as_cql_query_name_escaping(self): udt = UserType("MyKeyspace", "MyType", ["AbA", "keyspace"], ['ascii', 'ascii']) - self.assertEqual('CREATE TYPE "MyKeyspace"."MyType" ("AbA" ascii, "keyspace" ascii)', udt.as_cql_query(formatted=False)) + assert 'CREATE TYPE "MyKeyspace"."MyType" ("AbA" ascii, "keyspace" ascii)' == udt.as_cql_query(formatted=False) class UserDefinedFunctionTest(unittest.TestCase): @@ -594,7 +580,7 @@ def test_as_cql_query_removes_frozen(self): "LANGUAGE java " "AS $$return 0;$$" ) - self.assertEqual(expected_result, func.as_cql_query(formatted=False)) + assert expected_result == func.as_cql_query(formatted=False) class UserDefinedAggregateTest(unittest.TestCase): @@ -607,7 +593,7 
@@ def test_as_cql_query_removes_frozen(self): "FINALFUNC finalfunc " "INITCOND (0)" ) - self.assertEqual(expected_result, aggregate.as_cql_query(formatted=False)) + assert expected_result == aggregate.as_cql_query(formatted=False) class IndexTest(unittest.TestCase): @@ -622,14 +608,12 @@ def test_build_index_as_cql(self): row = {'index_name': 'index_name_here', 'index_type': 'index_type_here'} index_meta = parser._build_index_metadata(column_meta, row) - self.assertEqual(index_meta.as_cql_query(), - 'CREATE INDEX index_name_here ON keyspace_name_here.table_name_here (column_name_here)') + assert index_meta.as_cql_query() == 'CREATE INDEX index_name_here ON keyspace_name_here.table_name_here (column_name_here)' row['index_options'] = '{ "class_name": "class_name_here" }' row['index_type'] = 'CUSTOM' index_meta = parser._build_index_metadata(column_meta, row) - self.assertEqual(index_meta.as_cql_query(), - "CREATE CUSTOM INDEX index_name_here ON keyspace_name_here.table_name_here (column_name_here) USING 'class_name_here'") + assert index_meta.as_cql_query() == "CREATE CUSTOM INDEX index_name_here ON keyspace_name_here.table_name_here (column_name_here) USING 'class_name_here'" class UnicodeIdentifiersTests(unittest.TestCase): @@ -728,63 +712,39 @@ def _function_with_kwargs(self, **kwargs): ) def test_non_monotonic(self): - self.assertNotIn( - 'MONOTONIC', - self._function_with_kwargs( - monotonic=False, - monotonic_on=() - ).export_as_string() - ) + assert 'MONOTONIC' not in self._function_with_kwargs( + monotonic=False, + monotonic_on=() + ).export_as_string() def test_monotonic_all(self): mono_function = self._function_with_kwargs( monotonic=True, monotonic_on=() ) - self.assertIn( - 'MONOTONIC LANG', - mono_function.as_cql_query(formatted=False) - ) - self.assertIn( - 'MONOTONIC\n LANG', - mono_function.as_cql_query(formatted=True) - ) + assert 'MONOTONIC LANG' in mono_function.as_cql_query(formatted=False) + assert 'MONOTONIC\n LANG' in 
mono_function.as_cql_query(formatted=True) def test_monotonic_one(self): mono_on_function = self._function_with_kwargs( monotonic=False, monotonic_on=('x',) ) - self.assertIn( - 'MONOTONIC ON x LANG', - mono_on_function.as_cql_query(formatted=False) - ) - self.assertIn( - 'MONOTONIC ON x\n LANG', - mono_on_function.as_cql_query(formatted=True) - ) + assert 'MONOTONIC ON x LANG' in mono_on_function.as_cql_query(formatted=False) + assert 'MONOTONIC ON x\n LANG' in mono_on_function.as_cql_query(formatted=True) def test_nondeterministic(self): - self.assertNotIn( - 'DETERMINISTIC', - self._function_with_kwargs( - deterministic=False - ).as_cql_query(formatted=False) - ) + assert 'DETERMINISTIC' not in self._function_with_kwargs( + deterministic=False + ).as_cql_query(formatted=False) def test_deterministic(self): - self.assertIn( - 'DETERMINISTIC', - self._function_with_kwargs( - deterministic=True - ).as_cql_query(formatted=False) - ) - self.assertIn( - 'DETERMINISTIC\n', - self._function_with_kwargs( - deterministic=True - ).as_cql_query(formatted=True) - ) + assert 'DETERMINISTIC' in self._function_with_kwargs( + deterministic=True + ).as_cql_query(formatted=False) + assert 'DETERMINISTIC\n' in self._function_with_kwargs( + deterministic=True + ).as_cql_query(formatted=True) class AggregateToCQLTests(unittest.TestCase): @@ -806,21 +766,16 @@ def _aggregate_with_kwargs(self, **kwargs): ) def test_nondeterministic(self): - self.assertNotIn( - 'DETERMINISTIC', - self._aggregate_with_kwargs( - deterministic=False - ).as_cql_query(formatted=True) - ) + assert 'DETERMINISTIC' not in self._aggregate_with_kwargs( + deterministic=False + ).as_cql_query(formatted=True) def test_deterministic(self): for formatted in (True, False): query = self._aggregate_with_kwargs( deterministic=True ).as_cql_query(formatted=formatted) - self.assertTrue(query.endswith('DETERMINISTIC'), - msg="'DETERMINISTIC' not found in {}".format(query) - ) + assert query.endswith('DETERMINISTIC'), 
"'DETERMINISTIC' not found in {}".format(query) class HostsTests(unittest.TestCase): @@ -832,12 +787,12 @@ def test_iterate_all_hosts_and_modify(self): metadata.add_or_return_host(Host('dc1.1', SimpleConvictionPolicy)) metadata.add_or_return_host(Host('dc1.2', SimpleConvictionPolicy)) - self.assertEqual(len(metadata.all_hosts()), 2) + assert len(metadata.all_hosts()) == 2 for host in metadata.all_hosts(): # this would previously raise in Py3 metadata.remove_host(host) - self.assertEqual(len(metadata.all_hosts()), 0) + assert len(metadata.all_hosts()) == 0 class MetadataHelpersTest(unittest.TestCase): @@ -856,4 +811,4 @@ def test_strip_frozen(self): ] for argument, expected_result in argument_to_expected_results: result = strip_frozen(argument) - self.assertEqual(result, expected_result, "strip_frozen() arg: {}".format(argument)) + assert result == expected_result, "strip_frozen() arg: {}".format(argument) diff --git a/tests/unit/test_orderedmap.py b/tests/unit/test_orderedmap.py index 318b98a52d..156bbd5f30 100644 --- a/tests/unit/test_orderedmap.py +++ b/tests/unit/test_orderedmap.py @@ -16,6 +16,8 @@ from cassandra.util import OrderedMap, OrderedMapSerializedKey from cassandra.cqltypes import EMPTY, UTF8Type, lookup_casstype +from tests.util import assertListEqual +import pytest class OrderedMapTest(unittest.TestCase): def test_init(self): @@ -23,17 +25,17 @@ def test_init(self): b = OrderedMap([('one', 1), ('three', 3), ('two', 2)]) c = OrderedMap(a) builtin = {'one': 1, 'two': 2, 'three': 3} - self.assertEqual(a, b) - self.assertEqual(a, c) - self.assertEqual(a, builtin) - self.assertEqual(OrderedMap([(1, 1), (1, 2)]), {1: 2}) + assert a == b + assert a == c + assert a == builtin + assert OrderedMap([(1, 1), (1, 2)]) == {1: 2} d = OrderedMap({'': 3}, key1='v1', key2='v2') - self.assertEqual(d[''], 3) - self.assertEqual(d['key1'], 'v1') - self.assertEqual(d['key2'], 'v2') + assert d[''] == 3 + assert d['key1'] == 'v1' + assert d['key2'] == 'v2' - with 
self.assertRaises(TypeError): + with pytest.raises(TypeError): OrderedMap('too', 'many', 'args') def test_contains(self): @@ -44,41 +46,41 @@ def test_contains(self): om = OrderedMap(zip(keys, range(len(keys)))) for k in keys: - self.assertTrue(k in om) - self.assertFalse(k not in om) + assert k in om + assert not k not in om - self.assertTrue('notthere' not in om) - self.assertFalse('notthere' in om) + assert 'notthere' not in om + assert not 'notthere' in om def test_keys(self): keys = ['first', 'middle', 'last'] om = OrderedMap(zip(keys, range(len(keys)))) - self.assertListEqual(list(om.keys()), keys) + assertListEqual(list(om.keys()), keys) def test_values(self): keys = ['first', 'middle', 'last'] values = list(range(len(keys))) om = OrderedMap(zip(keys, values)) - self.assertListEqual(list(om.values()), values) + assertListEqual(list(om.values()), values) def test_items(self): keys = ['first', 'middle', 'last'] items = list(zip(keys, range(len(keys)))) om = OrderedMap(items) - self.assertListEqual(list(om.items()), items) + assertListEqual(list(om.items()), items) def test_get(self): keys = ['first', 'middle', 'last'] om = OrderedMap(zip(keys, range(len(keys)))) for v, k in enumerate(keys): - self.assertEqual(om.get(k), v) - - self.assertEqual(om.get('notthere', 'default'), 'default') - self.assertIsNone(om.get('notthere')) + assert om.get(k) == v + + assert om.get('notthere', 'default') == 'default' + assert om.get('notthere') is None def test_equal(self): d1 = {'one': 1} @@ -87,26 +89,26 @@ def test_equal(self): om12 = OrderedMap([('one', 1), ('two', 2)]) om21 = OrderedMap([('two', 2), ('one', 1)]) - self.assertEqual(om1, d1) - self.assertEqual(om12, d12) - self.assertEqual(om21, d12) - self.assertNotEqual(om1, om12) - self.assertNotEqual(om12, om1) - self.assertNotEqual(om12, om21) - self.assertNotEqual(om1, d12) - self.assertNotEqual(om12, d1) - self.assertNotEqual(om1, EMPTY) + assert om1 == d1 + assert om12 == d12 + assert om21 == d12 + assert om1 != 
om12 + assert om12 != om1 + assert om12 != om21 + assert om1 != d12 + assert om12 != d1 + assert om1 != EMPTY - self.assertFalse(OrderedMap([('three', 3), ('four', 4)]) == d12) + assert not OrderedMap([('three', 3), ('four', 4)]) == d12 def test_getitem(self): keys = ['first', 'middle', 'last'] om = OrderedMap(zip(keys, range(len(keys)))) for v, k in enumerate(keys): - self.assertEqual(om[k], v) - - with self.assertRaises(KeyError): + assert om[k] == v + + with pytest.raises(KeyError): om['notthere'] def test_iter(self): @@ -116,16 +118,17 @@ def test_iter(self): om = OrderedMap(items) itr = iter(om) - self.assertEqual(sum([1 for _ in itr]), len(keys)) - self.assertRaises(StopIteration, next, itr) + assert sum([1 for _ in itr]) == len(keys) + with pytest.raises(StopIteration): + next(itr) - self.assertEqual(list(iter(om)), keys) - self.assertEqual(list(om.items()), items) - self.assertEqual(list(om.values()), values) + assert list(iter(om)) == keys + assert list(om.items()) == items + assert list(om.values()) == values def test_len(self): - self.assertEqual(len(OrderedMap()), 0) - self.assertEqual(len(OrderedMap([(1, 1)])), 1) + assert len(OrderedMap()) == 0 + assert len(OrderedMap([(1, 1)])) == 1 def test_mutable_keys(self): d = {'1': 1} @@ -136,35 +139,36 @@ def test_strings(self): # changes in 3.x d = {'map': 'inner'} s = set([1, 2, 3]) - self.assertEqual(repr(OrderedMap([('two', 2), ('one', 1), (d, 'value'), (s, 'another')])), - "OrderedMap([('two', 2), ('one', 1), (%r, 'value'), (%r, 'another')])" % (d, s)) + assert repr(OrderedMap([('two', 2), ('one', 1), (d, 'value'), (s, 'another')])) == "OrderedMap([('two', 2), ('one', 1), (%r, 'value'), (%r, 'another')])" % (d, s) - self.assertEqual(str(OrderedMap([('two', 2), ('one', 1), (d, 'value'), (s, 'another')])), - "{'two': 2, 'one': 1, %r: 'value', %r: 'another'}" % (d, s)) + assert str(OrderedMap([('two', 2), ('one', 1), (d, 'value'), (s, 'another')])) == "{'two': 2, 'one': 1, %r: 'value', %r: 'another'}" % (d, 
s) def test_popitem(self): item = (1, 2) om = OrderedMap((item,)) - self.assertEqual(om.popitem(), item) - self.assertRaises(KeyError, om.popitem) + assert om.popitem() == item + with pytest.raises(KeyError): + om.popitem() def test_delitem(self): om = OrderedMap({1: 1, 2: 2}) - self.assertRaises(KeyError, om.__delitem__, 3) + with pytest.raises(KeyError): + om.__delitem__(3) del om[1] - self.assertEqual(om, {2: 2}) + assert om == {2: 2} del om[2] - self.assertFalse(om) + assert not om - self.assertRaises(KeyError, om.__delitem__, 1) + with pytest.raises(KeyError): + om.__delitem__(1) class OrderedMapSerializedKeyTest(unittest.TestCase): def test_init(self): om = OrderedMapSerializedKey(UTF8Type, 3) - self.assertEqual(om, {}) + assert om == {} def test_normalized_lookup(self): key_type = lookup_casstype('MapType(UTF8Type, Int32Type)') @@ -177,6 +181,6 @@ def test_normalized_lookup(self): # type lookup is normalized by key_type # PYTHON-231 - self.assertIs(om[{'one': 1}], om[{u'one': 1}]) - self.assertIs(om[{'two': 2}], om[{u'two': 2}]) - self.assertIsNot(om[{'one': 1}], om[{'two': 2}]) + assert om[{'one': 1}] is om[{u'one': 1}] + assert om[{'two': 2}] is om[{u'two': 2}] + assert om[{'one': 1}] is not om[{'two': 2}] diff --git a/tests/unit/test_parameter_binding.py b/tests/unit/test_parameter_binding.py index 9c557c0208..5416ac461d 100644 --- a/tests/unit/test_parameter_binding.py +++ b/tests/unit/test_parameter_binding.py @@ -13,6 +13,7 @@ # limitations under the License. 
import unittest +import pytest from cassandra.encoder import Encoder from cassandra.protocol import ColumnMetadata @@ -21,36 +22,38 @@ from cassandra.cqltypes import Int32Type from cassandra.util import OrderedDict +from tests.util import assertListEqual + class ParamBindingTest(unittest.TestCase): def test_bind_sequence(self): result = bind_params("%s %s %s", (1, "a", 2.0), Encoder()) - self.assertEqual(result, "1 'a' 2.0") + assert result == "1 'a' 2.0" def test_bind_map(self): result = bind_params("%(a)s %(b)s %(c)s", dict(a=1, b="a", c=2.0), Encoder()) - self.assertEqual(result, "1 'a' 2.0") + assert result == "1 'a' 2.0" def test_sequence_param(self): result = bind_params("%s", (ValueSequence((1, "a", 2.0)),), Encoder()) - self.assertEqual(result, "(1, 'a', 2.0)") + assert result == "(1, 'a', 2.0)" def test_generator_param(self): result = bind_params("%s", ((i for i in range(3)),), Encoder()) - self.assertEqual(result, "[0, 1, 2]") + assert result == "[0, 1, 2]" def test_none_param(self): result = bind_params("%s", (None,), Encoder()) - self.assertEqual(result, "NULL") + assert result == "NULL" def test_list_collection(self): result = bind_params("%s", (['a', 'b', 'c'],), Encoder()) - self.assertEqual(result, "['a', 'b', 'c']") + assert result == "['a', 'b', 'c']" def test_set_collection(self): result = bind_params("%s", (set(['a', 'b']),), Encoder()) - self.assertIn(result, ("{'a', 'b'}", "{'b', 'a'}")) + assert result in ("{'a', 'b'}", "{'b', 'a'}") def test_map_collection(self): vals = OrderedDict() @@ -58,15 +61,15 @@ def test_map_collection(self): vals['b'] = 'b' vals['c'] = 'c' result = bind_params("%s", (vals,), Encoder()) - self.assertEqual(result, "{'a': 'a', 'b': 'b', 'c': 'c'}") + assert result == "{'a': 'a', 'b': 'b', 'c': 'c'}" def test_quote_escaping(self): result = bind_params("%s", ("""'ef''ef"ef""ef'""",), Encoder()) - self.assertEqual(result, """'''ef''''ef"ef""ef'''""") + assert result == """'''ef''''ef"ef""ef'''""" def 
test_float_precision(self): f = 3.4028234663852886e+38 - self.assertEqual(float(bind_params("%s", (f,), Encoder())), f) + assert float(bind_params("%s", (f,), Encoder())) == f class BoundStatementTestV3(unittest.TestCase): @@ -89,25 +92,21 @@ def setUpClass(cls): def test_invalid_argument_type(self): values = (0, 0, 0, 'string not int') - try: + with pytest.raises(TypeError) as exc: self.bound.bind(values) - except TypeError as e: - self.assertIn('v0', str(e)) - self.assertIn('Int32Type', str(e)) - self.assertIn('str', str(e)) - else: - self.fail('Passed invalid type but exception was not thrown') + e = exc.value + assert 'v0' in str(e) + assert 'Int32Type' in str(e) + assert 'str' in str(e) values = (['1', '2'], 0, 0, 0) - try: + with pytest.raises(TypeError) as exc: self.bound.bind(values) - except TypeError as e: - self.assertIn('rk0', str(e)) - self.assertIn('Int32Type', str(e)) - self.assertIn('list', str(e)) - else: - self.fail('Passed invalid type but exception was not thrown') + e = exc.value + assert 'rk0' in str(e) + assert 'Int32Type' in str(e) + assert 'list' in str(e) def test_inherit_fetch_size(self): keyspace = 'keyspace1' @@ -128,29 +127,35 @@ def test_inherit_fetch_size(self): result_metadata_id=None) prepared_statement.fetch_size = 1234 bound_statement = BoundStatement(prepared_statement=prepared_statement) - self.assertEqual(1234, bound_statement.fetch_size) + assert 1234 == bound_statement.fetch_size def test_too_few_parameters_for_routing_key(self): - self.assertRaises(ValueError, self.prepared.bind, (1,)) + with pytest.raises(ValueError): + self.prepared.bind((1,)) bound = self.prepared.bind((1, 2)) - self.assertEqual(bound.keyspace, 'keyspace') + assert bound.keyspace == 'keyspace' def test_dict_missing_routing_key(self): - self.assertRaises(KeyError, self.bound.bind, {'rk0': 0, 'ck0': 0, 'v0': 0}) - self.assertRaises(KeyError, self.bound.bind, {'rk1': 0, 'ck0': 0, 'v0': 0}) + with pytest.raises(KeyError): + self.bound.bind({'rk0': 0, 'ck0': 
0, 'v0': 0}) + with pytest.raises(KeyError): + self.bound.bind({'rk1': 0, 'ck0': 0, 'v0': 0}) def test_missing_value(self): - self.assertRaises(KeyError, self.bound.bind, {'rk0': 0, 'rk1': 0, 'ck0': 0}) + with pytest.raises(KeyError): + self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0}) def test_extra_value(self): self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': 0, 'should_not_be_here': 123}) # okay to have extra keys in dict - self.assertEqual(self.bound.values, [b'\x00' * 4] * 4) # four encoded zeros - self.assertRaises(ValueError, self.bound.bind, (0, 0, 0, 0, 123)) + assert self.bound.values == [b'\x00' * 4] * 4 # four encoded zeros + with pytest.raises(ValueError): + self.bound.bind((0, 0, 0, 0, 123)) def test_values_none(self): # should have values - self.assertRaises(ValueError, self.bound.bind, None) + with pytest.raises(ValueError): + self.bound.bind(None) # prepared statement with no values prepared_statement = PreparedStatement(column_metadata=[], @@ -162,20 +167,22 @@ def test_values_none(self): result_metadata=None, result_metadata_id=None) bound = prepared_statement.bind(None) - self.assertListEqual(bound.values, []) + assertListEqual(bound.values, []) def test_bind_none(self): self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': None}) - self.assertEqual(self.bound.values[-1], None) + assert self.bound.values[-1] == None old_values = self.bound.values self.bound.bind((0, 0, 0, None)) - self.assertIsNot(self.bound.values, old_values) - self.assertEqual(self.bound.values[-1], None) + assert self.bound.values is not old_values + assert self.bound.values[-1] == None def test_unset_value(self): - self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE}) - self.assertRaises(ValueError, self.bound.bind, (0, 0, 0, UNSET_VALUE)) + with pytest.raises(ValueError): + self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE}) + with pytest.raises(ValueError): + self.bound.bind((0, 0, 0, UNSET_VALUE)) class 
BoundStatementTestV4(BoundStatementTestV3): @@ -184,25 +191,27 @@ class BoundStatementTestV4(BoundStatementTestV3): def test_dict_missing_routing_key(self): # in v4 it implicitly binds UNSET_VALUE for missing items, # UNSET_VALUE is ValueError for routing keys - self.assertRaises(ValueError, self.bound.bind, {'rk0': 0, 'ck0': 0, 'v0': 0}) - self.assertRaises(ValueError, self.bound.bind, {'rk1': 0, 'ck0': 0, 'v0': 0}) + with pytest.raises(ValueError): + self.bound.bind({'rk0': 0, 'ck0': 0, 'v0': 0}) + with pytest.raises(ValueError): + self.bound.bind({'rk1': 0, 'ck0': 0, 'v0': 0}) def test_missing_value(self): # in v4 missing values are UNSET_VALUE self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0}) - self.assertEqual(self.bound.values[-1], UNSET_VALUE) + assert self.bound.values[-1] == UNSET_VALUE old_values = self.bound.values self.bound.bind((0, 0, 0)) - self.assertIsNot(self.bound.values, old_values) - self.assertEqual(self.bound.values[-1], UNSET_VALUE) + assert self.bound.values is not old_values + assert self.bound.values[-1] == UNSET_VALUE def test_unset_value(self): self.bound.bind({'rk0': 0, 'rk1': 0, 'ck0': 0, 'v0': UNSET_VALUE}) - self.assertEqual(self.bound.values[-1], UNSET_VALUE) + assert self.bound.values[-1] == UNSET_VALUE self.bound.bind((0, 0, 0, UNSET_VALUE)) - self.assertEqual(self.bound.values[-1], UNSET_VALUE) + assert self.bound.values[-1] == UNSET_VALUE class BoundStatementTestV5(BoundStatementTestV4): diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index e7757aedfc..c98511ab34 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -48,16 +48,24 @@ def test_non_implemented(self): host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) host.set_location_info("dc1", "rack1") - self.assertRaises(NotImplementedError, policy.distance, host) - self.assertRaises(NotImplementedError, policy.populate, None, host) - self.assertRaises(NotImplementedError, policy.make_query_plan) - 
self.assertRaises(NotImplementedError, policy.on_up, host) - self.assertRaises(NotImplementedError, policy.on_down, host) - self.assertRaises(NotImplementedError, policy.on_add, host) - self.assertRaises(NotImplementedError, policy.on_remove, host) + with pytest.raises(NotImplementedError): + policy.distance(host) + with pytest.raises(NotImplementedError): + policy.populate(None, host) + with pytest.raises(NotImplementedError): + policy.make_query_plan() + with pytest.raises(NotImplementedError): + policy.on_up(host) + with pytest.raises(NotImplementedError): + policy.on_down(host) + with pytest.raises(NotImplementedError): + policy.on_add(host) + with pytest.raises(NotImplementedError): + policy.on_remove(host) def test_instance_check(self): - self.assertRaises(TypeError, Cluster, load_balancing_policy=RoundRobinPolicy) + with pytest.raises(TypeError): + Cluster(load_balancing_policy=RoundRobinPolicy) class RoundRobinPolicyTest(unittest.TestCase): @@ -67,7 +75,7 @@ def test_basic(self): policy = RoundRobinPolicy() policy.populate(None, hosts) qplan = list(policy.make_query_plan()) - self.assertEqual(sorted(qplan), hosts) + assert sorted(qplan) == hosts def test_multiple_query_plans(self): hosts = [0, 1, 2, 3] @@ -75,13 +83,13 @@ def test_multiple_query_plans(self): policy.populate(None, hosts) for i in range(20): qplan = list(policy.make_query_plan()) - self.assertEqual(sorted(qplan), hosts) + assert sorted(qplan) == hosts def test_single_host(self): policy = RoundRobinPolicy() policy.populate(None, [0]) qplan = list(policy.make_query_plan()) - self.assertEqual(qplan, [0]) + assert qplan == [0] def test_status_updates(self): hosts = [0, 1, 2, 3] @@ -92,7 +100,7 @@ def test_status_updates(self): policy.on_up(4) policy.on_add(5) qplan = list(policy.make_query_plan()) - self.assertEqual(sorted(qplan), [2, 3, 4, 5]) + assert sorted(qplan) == [2, 3, 4, 5] def test_thread_safety(self): hosts = range(100) @@ -102,7 +110,7 @@ def test_thread_safety(self): def 
check_query_plan(): for i in range(100): qplan = list(policy.make_query_plan()) - self.assertEqual(sorted(qplan), list(hosts)) + assert sorted(qplan) == list(hosts) threads = [Thread(target=check_query_plan) for i in range(4)] for t in threads: @@ -161,8 +169,7 @@ def host_down(): else: sys.setswitchinterval(original_interval) - if errors: - self.fail("Saw errors: %s" % (errors,)) + assert not errors, "Saw errors: %s" % (errors,) def test_no_live_nodes(self): """ @@ -176,7 +183,7 @@ def test_no_live_nodes(self): policy.on_down(i) qplan = list(policy.make_query_plan()) - self.assertEqual(qplan, []) + assert qplan == [] @pytest.mark.parametrize("policy_specialization, constructor_args", [(DCAwareRoundRobinPolicy, ("dc1", )), (RackAwareRoundRobinPolicy, ("dc1", "rack1"))]) class TestRackOrDCAwareRoundRobinPolicy: @@ -591,14 +598,14 @@ def get_replicas(keyspace, packed_key): replicas = get_replicas(None, struct.pack('>i', i)) other = set(h for h in hosts if h not in replicas) - self.assertEqual(replicas, qplan[:2]) - self.assertEqual(other, set(qplan[2:])) + assert replicas == qplan[:2] + assert other == set(qplan[2:]) # Should use the secondary policy for i in range(4): qplan = list(policy.make_query_plan()) - self.assertEqual(set(qplan), set(hosts)) + assert set(qplan) == set(hosts) def test_wrap_dc_aware(self): cluster = Mock(spec=Cluster) @@ -631,17 +638,17 @@ def get_replicas(keyspace, packed_key): replicas = get_replicas(None, struct.pack('>i', i)) # first should be the only local replica - self.assertIn(qplan[0], replicas) - self.assertEqual(qplan[0].datacenter, "dc1") + assert qplan[0] in replicas + assert qplan[0].datacenter == "dc1" # then the local non-replica - self.assertNotIn(qplan[1], replicas) - self.assertEqual(qplan[1].datacenter, "dc1") + assert qplan[1] not in replicas + assert qplan[1].datacenter == "dc1" # then one of the remotes (used_hosts_per_remote_dc is 1, so we # shouldn't see two remotes) - self.assertEqual(qplan[2].datacenter, "dc2") - 
self.assertEqual(3, len(qplan)) + assert qplan[2].datacenter == "dc2" + assert 3 == len(qplan) class FakeCluster: def __init__(self): @@ -661,20 +668,20 @@ def test_get_distance(self): policy.populate(self.FakeCluster(), [host]) - self.assertEqual(policy.distance(host), HostDistance.LOCAL) + assert policy.distance(host) == HostDistance.LOCAL # used_hosts_per_remote_dc is set to 0, so ignore it remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy) remote_host.set_location_info("dc2", "rack1") - self.assertEqual(policy.distance(remote_host), HostDistance.IGNORED) + assert policy.distance(remote_host) == HostDistance.IGNORED # dc2 isn't registered in the policy's live_hosts dict policy._child_policy.used_hosts_per_remote_dc = 1 - self.assertEqual(policy.distance(remote_host), HostDistance.IGNORED) + assert policy.distance(remote_host) == HostDistance.IGNORED # make sure the policy has both dcs registered policy.populate(self.FakeCluster(), [host, remote_host]) - self.assertEqual(policy.distance(remote_host), HostDistance.REMOTE) + assert policy.distance(remote_host) == HostDistance.REMOTE # since used_hosts_per_remote_dc is set to 1, only the first # remote host in dc2 will be REMOTE, the rest are IGNORED @@ -682,7 +689,7 @@ def test_get_distance(self): second_remote_host.set_location_info("dc2", "rack1") policy.populate(self.FakeCluster(), [host, remote_host, second_remote_host]) distances = set([policy.distance(remote_host), policy.distance(second_remote_host)]) - self.assertEqual(distances, set([HostDistance.REMOTE, HostDistance.IGNORED])) + assert distances == set([HostDistance.REMOTE, HostDistance.IGNORED]) def test_status_updates(self): """ @@ -710,21 +717,21 @@ def test_status_updates(self): # we now have two local hosts and two remote hosts in separate dcs qplan = list(policy.make_query_plan()) - self.assertEqual(set(qplan[:2]), set([hosts[1], new_local_host])) - self.assertEqual(set(qplan[2:]), set([hosts[3], new_remote_host])) + assert 
set(qplan[:2]) == set([hosts[1], new_local_host]) + assert set(qplan[2:]) == set([hosts[3], new_remote_host]) # since we have hosts in dc9000, the distance shouldn't be IGNORED - self.assertEqual(policy.distance(new_remote_host), HostDistance.REMOTE) + assert policy.distance(new_remote_host) == HostDistance.REMOTE policy.on_down(new_local_host) policy.on_down(hosts[1]) qplan = list(policy.make_query_plan()) - self.assertEqual(set(qplan), set([hosts[3], new_remote_host])) + assert set(qplan) == set([hosts[3], new_remote_host]) policy.on_down(new_remote_host) policy.on_down(hosts[3]) qplan = list(policy.make_query_plan()) - self.assertEqual(qplan, []) + assert qplan == [] def test_statement_keyspace(self): hosts = [Host(DefaultEndPoint(str(i)), SimpleConvictionPolicy) for i in range(4)] @@ -749,8 +756,8 @@ def test_statement_keyspace(self): routing_key = 'routing_key' query = Statement(routing_key=routing_key) qplan = list(policy.make_query_plan(keyspace, query)) - self.assertEqual(hosts, qplan) - self.assertEqual(cluster.metadata.get_replicas.call_count, 0) + assert hosts == qplan + assert cluster.metadata.get_replicas.call_count == 0 child_policy.make_query_plan.assert_called_once_with(keyspace, query) # working keyspace, no statement @@ -759,7 +766,7 @@ def test_statement_keyspace(self): routing_key = 'routing_key' query = Statement(routing_key=routing_key) qplan = list(policy.make_query_plan(keyspace, query)) - self.assertEqual(replicas + hosts[:2], qplan) + assert replicas + hosts[:2] == qplan cluster.metadata.get_replicas.assert_called_with(keyspace, routing_key) # statement keyspace, no working @@ -769,7 +776,7 @@ def test_statement_keyspace(self): routing_key = 'routing_key' query = Statement(routing_key=routing_key, keyspace=statement_keyspace) qplan = list(policy.make_query_plan(working_keyspace, query)) - self.assertEqual(replicas + hosts[:2], qplan) + assert replicas + hosts[:2] == qplan cluster.metadata.get_replicas.assert_called_with(statement_keyspace, 
routing_key) # both keyspaces set, statement keyspace used for routing @@ -779,7 +786,7 @@ def test_statement_keyspace(self): routing_key = 'routing_key' query = Statement(routing_key=routing_key, keyspace=statement_keyspace) qplan = list(policy.make_query_plan(working_keyspace, query)) - self.assertEqual(replicas + hosts[:2], qplan) + assert replicas + hosts[:2] == qplan cluster.metadata.get_replicas.assert_called_with(statement_keyspace, routing_key) def test_shuffles_if_given_keyspace_and_routing_key(self): @@ -840,15 +847,15 @@ def _assert_shuffle(self, patched_shuffle, keyspace, routing_key): query = Statement(routing_key=routing_key) qplan = list(policy.make_query_plan(keyspace, query)) if keyspace is None or routing_key is None: - self.assertEqual(hosts, qplan) - self.assertEqual(cluster.metadata.get_replicas.call_count, 0) + assert hosts == qplan + assert cluster.metadata.get_replicas.call_count == 0 child_policy.make_query_plan.assert_called_once_with(keyspace, query) - self.assertEqual(patched_shuffle.call_count, 0) + assert patched_shuffle.call_count == 0 else: - self.assertEqual(set(replicas), set(qplan[:2])) - self.assertEqual(hosts[:2], qplan[2:]) + assert set(replicas) == set(qplan[:2]) + assert hosts[:2] == qplan[2:] child_policy.make_query_plan.assert_called_once_with(keyspace, query) - self.assertEqual(patched_shuffle.call_count, 1) + assert patched_shuffle.call_count == 1 class ConvictionPolicyTest(unittest.TestCase): @@ -858,8 +865,10 @@ def test_not_implemented(self): """ conviction_policy = ConvictionPolicy(1) - self.assertRaises(NotImplementedError, conviction_policy.add_failure, 1) - self.assertRaises(NotImplementedError, conviction_policy.reset) + with pytest.raises(NotImplementedError): + conviction_policy.add_failure(1) + with pytest.raises(NotImplementedError): + conviction_policy.reset() class SimpleConvictionPolicyTest(unittest.TestCase): @@ -869,8 +878,8 @@ def test_basic_responses(self): """ conviction_policy = 
SimpleConvictionPolicy(1) - self.assertEqual(conviction_policy.add_failure(1), True) - self.assertEqual(conviction_policy.reset(), None) + assert conviction_policy.add_failure(1) == True + assert conviction_policy.reset() == None class ReconnectionPolicyTest(unittest.TestCase): @@ -880,7 +889,8 @@ def test_basic_responses(self): """ policy = ReconnectionPolicy() - self.assertRaises(NotImplementedError, policy.new_schedule) + with pytest.raises(NotImplementedError): + policy.new_schedule() class ConstantReconnectionPolicyTest(unittest.TestCase): @@ -890,20 +900,21 @@ def test_bad_vals(self): Test initialization values """ - self.assertRaises(ValueError, ConstantReconnectionPolicy, -1, 0) + with pytest.raises(ValueError): + ConstantReconnectionPolicy(-1, 0) def test_schedule(self): """ Test ConstantReconnectionPolicy schedule """ - delay = 2 + configured_delay = 2 max_attempts = 100 - policy = ConstantReconnectionPolicy(delay=delay, max_attempts=max_attempts) + policy = ConstantReconnectionPolicy(delay=configured_delay, max_attempts=max_attempts) schedule = list(policy.new_schedule()) - self.assertEqual(len(schedule), max_attempts) + assert len(schedule) == max_attempts for i, delay in enumerate(schedule): - self.assertEqual(delay, delay) + assert delay == configured_delay def test_schedule_negative_max_attempts(self): """ @@ -913,11 +924,8 @@ def test_schedule_negative_max_attempts(self): delay = 2 max_attempts = -100 - try: + with pytest.raises(ValueError): ConstantReconnectionPolicy(delay=delay, max_attempts=max_attempts) - self.fail('max_attempts should throw ValueError when negative') - except ValueError: - pass def test_schedule_infinite_attempts(self): delay = 2 @@ -925,19 +933,23 @@ def test_schedule_infinite_attempts(self): crp = ConstantReconnectionPolicy(delay=delay, max_attempts=max_attempts) # this is infinite. 
we'll just verify one more than default for _, d in zip(range(65), crp.new_schedule()): - self.assertEqual(d, delay) + assert d == delay class ExponentialReconnectionPolicyTest(unittest.TestCase): def _assert_between(self, value, min, max): - self.assertTrue(min <= value <= max) + assert min <= value <= max def test_bad_vals(self): - self.assertRaises(ValueError, ExponentialReconnectionPolicy, -1, 0) - self.assertRaises(ValueError, ExponentialReconnectionPolicy, 0, -1) - self.assertRaises(ValueError, ExponentialReconnectionPolicy, 9000, 1) - self.assertRaises(ValueError, ExponentialReconnectionPolicy, 1, 2, -1) + with pytest.raises(ValueError): + ExponentialReconnectionPolicy(-1, 0) + with pytest.raises(ValueError): + ExponentialReconnectionPolicy(0, -1) + with pytest.raises(ValueError): + ExponentialReconnectionPolicy(9000, 1) + with pytest.raises(ValueError): + ExponentialReconnectionPolicy(1, 2, -1) def test_schedule_no_max(self): base_delay = 2.0 @@ -947,7 +959,7 @@ def test_schedule_no_max(self): sched_slice = list(islice(policy.new_schedule(), 0, test_iter)) self._assert_between(sched_slice[0], base_delay*0.85, base_delay*1.15) self._assert_between(sched_slice[-1], max_delay*0.85, max_delay*1.15) - self.assertEqual(len(sched_slice), test_iter) + assert len(sched_slice) == test_iter def test_schedule_with_max(self): base_delay = 2.0 @@ -955,7 +967,7 @@ def test_schedule_with_max(self): max_attempts = 64 policy = ExponentialReconnectionPolicy(base_delay=base_delay, max_delay=max_delay, max_attempts=max_attempts) schedule = list(policy.new_schedule()) - self.assertEqual(len(schedule), max_attempts) + assert len(schedule) == max_attempts for i, delay in enumerate(schedule): if i == 0: self._assert_between(delay, base_delay*0.85, base_delay*1.15) @@ -972,7 +984,7 @@ def test_schedule_exactly_one_attempt(self): policy = ExponentialReconnectionPolicy( base_delay=base_delay, max_delay=max_delay, max_attempts=max_attempts ) - 
self.assertEqual(len(list(policy.new_schedule())), 1) + assert len(list(policy.new_schedule())) == 1 def test_schedule_overflow(self): """ @@ -996,7 +1008,7 @@ def test_schedule_overflow(self): policy = ExponentialReconnectionPolicy(base_delay=base_delay, max_delay=max_delay, max_attempts=max_attempts) schedule = list(policy.new_schedule()) for number in schedule: - self.assertLessEqual(number, sys.float_info.max) + assert number <= sys.float_info.max def test_schedule_with_jitter(self): """ @@ -1030,29 +1042,29 @@ def test_read_timeout(self): retry, consistency = policy.on_read_timeout( query=None, consistency=ONE, required_responses=1, received_responses=2, data_retrieved=True, retry_num=1) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None # if we didn't get enough responses, rethrow retry, consistency = policy.on_read_timeout( query=None, consistency=ONE, required_responses=2, received_responses=1, data_retrieved=True, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None # if we got enough responses, but also got a data response, rethrow retry, consistency = policy.on_read_timeout( query=None, consistency=ONE, required_responses=2, received_responses=2, data_retrieved=True, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None # we got enough responses but no data response, so retry retry, consistency = policy.on_read_timeout( query=None, consistency=ONE, required_responses=2, received_responses=2, data_retrieved=False, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETRY) - self.assertEqual(consistency, ONE) + assert retry == RetryPolicy.RETRY + assert consistency == ONE def test_write_timeout(self): policy = RetryPolicy() @@ -1061,22 
+1073,22 @@ def test_write_timeout(self): retry, consistency = policy.on_write_timeout( query=None, consistency=ONE, write_type=WriteType.SIMPLE, required_responses=1, received_responses=2, retry_num=1) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None # if it's not a BATCH_LOG write, don't retry it retry, consistency = policy.on_write_timeout( query=None, consistency=ONE, write_type=WriteType.SIMPLE, required_responses=1, received_responses=2, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None # retry BATCH_LOG writes regardless of received responses retry, consistency = policy.on_write_timeout( query=None, consistency=ONE, write_type=WriteType.BATCH_LOG, required_responses=10000, received_responses=1, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETRY) - self.assertEqual(consistency, ONE) + assert retry == RetryPolicy.RETRY + assert consistency == ONE def test_unavailable(self): """ @@ -1087,20 +1099,20 @@ def test_unavailable(self): retry, consistency = policy.on_unavailable( query=None, consistency=ONE, required_replicas=1, alive_replicas=2, retry_num=1) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None retry, consistency = policy.on_unavailable( query=None, consistency=ONE, required_replicas=1, alive_replicas=2, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETRY_NEXT_HOST) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETRY_NEXT_HOST + assert consistency == None retry, consistency = policy.on_unavailable( query=None, consistency=ONE, required_replicas=10000, alive_replicas=1, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETRY_NEXT_HOST) - self.assertEqual(consistency, None) + assert retry == 
RetryPolicy.RETRY_NEXT_HOST + assert consistency == None class FallthroughRetryPolicyTest(unittest.TestCase): @@ -1115,26 +1127,26 @@ def test_read_timeout(self): retry, consistency = policy.on_read_timeout( query=None, consistency=ONE, required_responses=1, received_responses=2, data_retrieved=True, retry_num=1) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None retry, consistency = policy.on_read_timeout( query=None, consistency=ONE, required_responses=2, received_responses=1, data_retrieved=True, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None retry, consistency = policy.on_read_timeout( query=None, consistency=ONE, required_responses=2, received_responses=2, data_retrieved=True, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None retry, consistency = policy.on_read_timeout( query=None, consistency=ONE, required_responses=2, received_responses=2, data_retrieved=False, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None def test_write_timeout(self): policy = FallthroughRetryPolicy() @@ -1142,20 +1154,20 @@ def test_write_timeout(self): retry, consistency = policy.on_write_timeout( query=None, consistency=ONE, write_type=WriteType.SIMPLE, required_responses=1, received_responses=2, retry_num=1) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None retry, consistency = policy.on_write_timeout( query=None, consistency=ONE, write_type=WriteType.SIMPLE, required_responses=1, received_responses=2, retry_num=0) - self.assertEqual(retry, 
RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None retry, consistency = policy.on_write_timeout( query=None, consistency=ONE, write_type=WriteType.BATCH_LOG, required_responses=10000, received_responses=1, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None def test_unavailable(self): policy = FallthroughRetryPolicy() @@ -1163,20 +1175,20 @@ def test_unavailable(self): retry, consistency = policy.on_unavailable( query=None, consistency=ONE, required_replicas=1, alive_replicas=2, retry_num=1) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None retry, consistency = policy.on_unavailable( query=None, consistency=ONE, required_replicas=1, alive_replicas=2, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None retry, consistency = policy.on_unavailable( query=None, consistency=ONE, required_replicas=10000, alive_replicas=1, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None class DowngradingConsistencyRetryPolicyTest(unittest.TestCase): @@ -1188,50 +1200,50 @@ def test_read_timeout(self): retry, consistency = policy.on_read_timeout( query=None, consistency=ONE, required_responses=1, received_responses=2, data_retrieved=True, retry_num=1) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None # if we didn't get enough responses, retry at a lower consistency retry, consistency = policy.on_read_timeout( query=None, consistency=ONE, required_responses=4, received_responses=3, 
data_retrieved=True, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETRY) - self.assertEqual(consistency, ConsistencyLevel.THREE) + assert retry == RetryPolicy.RETRY + assert consistency == ConsistencyLevel.THREE # if we didn't get enough responses, retry at a lower consistency retry, consistency = policy.on_read_timeout( query=None, consistency=ONE, required_responses=3, received_responses=2, data_retrieved=True, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETRY) - self.assertEqual(consistency, ConsistencyLevel.TWO) + assert retry == RetryPolicy.RETRY + assert consistency == ConsistencyLevel.TWO # retry consistency level goes down based on the # of recv'd responses retry, consistency = policy.on_read_timeout( query=None, consistency=ONE, required_responses=3, received_responses=1, data_retrieved=True, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETRY) - self.assertEqual(consistency, ConsistencyLevel.ONE) + assert retry == RetryPolicy.RETRY + assert consistency == ConsistencyLevel.ONE # if we got no responses, rethrow retry, consistency = policy.on_read_timeout( query=None, consistency=ONE, required_responses=3, received_responses=0, data_retrieved=True, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None # if we got enough response but no data, retry retry, consistency = policy.on_read_timeout( query=None, consistency=ONE, required_responses=3, received_responses=3, data_retrieved=False, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETRY) - self.assertEqual(consistency, ONE) + assert retry == RetryPolicy.RETRY + assert consistency == ONE # if we got enough responses, but also got a data response, rethrow retry, consistency = policy.on_read_timeout( query=None, consistency=ONE, required_responses=2, received_responses=2, data_retrieved=True, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETHROW) - 
self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None def test_write_timeout(self): policy = DowngradingConsistencyRetryPolicy() @@ -1240,41 +1252,41 @@ def test_write_timeout(self): retry, consistency = policy.on_write_timeout( query=None, consistency=ONE, write_type=WriteType.SIMPLE, required_responses=1, received_responses=2, retry_num=1) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None for write_type in (WriteType.SIMPLE, WriteType.BATCH, WriteType.COUNTER): # ignore failures if at least one response (replica persisted) retry, consistency = policy.on_write_timeout( query=None, consistency=ONE, write_type=write_type, required_responses=1, received_responses=2, retry_num=0) - self.assertEqual(retry, RetryPolicy.IGNORE) + assert retry == RetryPolicy.IGNORE # retrhow if we can't be sure we have a replica retry, consistency = policy.on_write_timeout( query=None, consistency=ONE, write_type=write_type, required_responses=1, received_responses=0, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETHROW) + assert retry == RetryPolicy.RETHROW # downgrade consistency level on unlogged batch writes retry, consistency = policy.on_write_timeout( query=None, consistency=ONE, write_type=WriteType.UNLOGGED_BATCH, required_responses=3, received_responses=1, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETRY) - self.assertEqual(consistency, ConsistencyLevel.ONE) + assert retry == RetryPolicy.RETRY + assert consistency == ConsistencyLevel.ONE # retry batch log writes at the same consistency level retry, consistency = policy.on_write_timeout( query=None, consistency=ONE, write_type=WriteType.BATCH_LOG, required_responses=3, received_responses=1, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETRY) - self.assertEqual(consistency, ONE) + assert retry == RetryPolicy.RETRY + assert consistency == ONE # timeout on an 
unknown write_type retry, consistency = policy.on_write_timeout( query=None, consistency=ONE, write_type=None, required_responses=1, received_responses=2, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None def test_unavailable(self): policy = DowngradingConsistencyRetryPolicy() @@ -1282,14 +1294,14 @@ def test_unavailable(self): # if this is the second or greater attempt, rethrow retry, consistency = policy.on_unavailable( query=None, consistency=ONE, required_replicas=3, alive_replicas=1, retry_num=1) - self.assertEqual(retry, RetryPolicy.RETHROW) - self.assertEqual(consistency, None) + assert retry == RetryPolicy.RETHROW + assert consistency == None # downgrade consistency on unavailable exceptions retry, consistency = policy.on_unavailable( query=None, consistency=ONE, required_replicas=3, alive_replicas=1, retry_num=0) - self.assertEqual(retry, RetryPolicy.RETRY) - self.assertEqual(consistency, ConsistencyLevel.ONE) + assert retry == RetryPolicy.RETRY + assert consistency == ConsistencyLevel.ONE class ExponentialRetryPolicyTest(unittest.TestCase): @@ -1319,9 +1331,9 @@ def test_hosts_with_hostname(self): policy.populate(None, [host]) qplan = list(policy.make_query_plan()) - self.assertEqual(sorted(qplan), [host]) + assert sorted(qplan) == [host] - self.assertEqual(policy.distance(host), HostDistance.LOCAL) + assert policy.distance(host) == HostDistance.LOCAL def test_hosts_with_socket_hostname(self): hosts = [UnixSocketEndPoint('/tmp/scylla-workdir/cql.m')] @@ -1330,9 +1342,9 @@ def test_hosts_with_socket_hostname(self): policy.populate(None, [host]) qplan = list(policy.make_query_plan()) - self.assertEqual(sorted(qplan), [host]) + assert sorted(qplan) == [host] - self.assertEqual(policy.distance(host), HostDistance.LOCAL) + assert policy.distance(host) == HostDistance.LOCAL class AddressTranslatorTest(unittest.TestCase): @@ -1345,8 +1357,8 @@ def 
test_ec2_multi_region_translator(self, *_): ec2t = EC2MultiRegionTranslator() addr = '127.0.0.1' translated = ec2t.translate(addr) - self.assertIsNot(translated, addr) # verifies that the resolver path is followed - self.assertEqual(translated, addr) # and that it resolves to the same address + assert translated is not addr # verifies that the resolver path is followed + assert translated == addr # and that it resolves to the same address class HostFilterPolicyInitTest(unittest.TestCase): @@ -1356,8 +1368,8 @@ def setUp(self): Mock(name='predicate')) def _check_init(self, hfp): - self.assertIs(hfp._child_policy, self.child_policy) - self.assertIsInstance(hfp._hosts_lock, LockType) + assert hfp._child_policy is self.child_policy + assert isinstance(hfp._hosts_lock, LockType) # we can't use a simple assertIs because we wrap the function arg0, arg1 = Mock(name='arg0'), Mock(name='arg1') @@ -1380,7 +1392,7 @@ def test_immutable_predicate(self): expected_message_regex = "can't set attribute" hfp = HostFilterPolicy(child_policy=Mock(name='child_policy'), predicate=Mock(name='predicate')) - with self.assertRaisesRegex(AttributeError, expected_message_regex): + with pytest.raises(AttributeError, match=expected_message_regex): hfp.predicate = object() @@ -1408,7 +1420,7 @@ def _check_host_triggered_method(self, policy, name): # method calls the child policy's method... 
child_policy_method.assert_called_once_with(arg, kw=kwarg) # and returns its return value - self.assertIs(result, child_policy_method.return_value) + assert result is child_policy_method.return_value def test_defer_on_up_to_child_policy(self): self._check_host_triggered_method(self.passthrough_hfp, 'on_up') @@ -1456,10 +1468,8 @@ def setUp(self): self.accepted_host = Host(DefaultEndPoint('acceptme'), conviction_policy_factory=Mock()) def test_ignored_with_filter(self): - self.assertEqual(self.hfp.distance(self.ignored_host), - HostDistance.IGNORED) - self.assertNotEqual(self.hfp.distance(self.accepted_host), - HostDistance.IGNORED) + assert self.hfp.distance(self.ignored_host) == HostDistance.IGNORED + assert self.hfp.distance(self.accepted_host) != HostDistance.IGNORED def test_accepted_filter_defers_to_child_policy(self): self.hfp._child_policy.distance.side_effect = distances = Mock(), Mock() @@ -1467,9 +1477,9 @@ def test_accepted_filter_defers_to_child_policy(self): # getting the distance for an ignored host shouldn't affect subsequent results self.hfp.distance(self.ignored_host) # first call of _child_policy with count() side effect - self.assertEqual(self.hfp.distance(self.accepted_host), distances[0]) + assert self.hfp.distance(self.accepted_host) == distances[0] # second call of _child_policy with count() side effect - self.assertEqual(self.hfp.distance(self.accepted_host), distances[1]) + assert self.hfp.distance(self.accepted_host) == distances[1] class HostFilterPolicyPopulateTest(unittest.TestCase): @@ -1496,10 +1506,7 @@ def test_child_is_populated_with_filtered_hosts(self): ['acceptme0', 'acceptme1']) hfp.populate(mock_cluster, hosts) hfp._child_policy.populate.assert_called_once() - self.assertEqual( - hfp._child_policy.populate.call_args[1]['hosts'], - ['acceptme0', 'acceptme1'] - ) + assert hfp._child_policy.populate.call_args[1]['hosts'] == ['acceptme0', 'acceptme1'] class HostFilterPolicyQueryPlanTest(unittest.TestCase): @@ -1523,7 +1530,7 @@ 
def test_query_plan_deferred_to_child(self): working_keyspace=working_keyspace, query=query ) - self.assertEqual(qp, hfp._child_policy.make_query_plan.return_value) + assert qp == hfp._child_policy.make_query_plan.return_value def test_wrap_token_aware(self): cluster = Mock(spec=Cluster) @@ -1553,9 +1560,9 @@ def get_replicas(keyspace, packed_key): query_plan = hfp.make_query_plan("keyspace", mocked_query) # First the not filtered replica, and then the rest of the allowed hosts ordered query_plan = list(query_plan) - self.assertEqual(query_plan[0], Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy)) - self.assertEqual(set(query_plan[1:]),{Host(DefaultEndPoint("127.0.0.3"), SimpleConvictionPolicy), - Host(DefaultEndPoint("127.0.0.5"), SimpleConvictionPolicy)}) + assert query_plan[0] == Host(DefaultEndPoint("127.0.0.2"), SimpleConvictionPolicy) + assert set(query_plan[1:]) == {Host(DefaultEndPoint("127.0.0.3"), SimpleConvictionPolicy), + Host(DefaultEndPoint("127.0.0.5"), SimpleConvictionPolicy)} def test_create_whitelist(self): cluster = Mock(spec=Cluster) @@ -1577,5 +1584,5 @@ def test_create_whitelist(self): mocked_query = Mock() query_plan = hfp.make_query_plan("keyspace", mocked_query) # Only the filtered replicas should be allowed - self.assertEqual(set(query_plan), {Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy), - Host(DefaultEndPoint("127.0.0.4"), SimpleConvictionPolicy)}) + assert set(query_plan) == {Host(DefaultEndPoint("127.0.0.1"), SimpleConvictionPolicy), + Host(DefaultEndPoint("127.0.0.4"), SimpleConvictionPolicy)} diff --git a/tests/unit/test_protocol.py b/tests/unit/test_protocol.py index 907f62f2bb..57261654df 100644 --- a/tests/unit/test_protocol.py +++ b/tests/unit/test_protocol.py @@ -26,6 +26,7 @@ from cassandra.query import BatchType from cassandra.marshal import uint32_unpack from cassandra.cluster import ContinuousPagingOptions +import pytest class MessageTest(unittest.TestCase): @@ -88,10 +89,7 @@ def 
test_query_message(self): self._check_calls(io, [(b'\x00\x00\x00\x01',), (b'a',), (b'\x00\x03',), (b'\x00\x00\x00\x00',)]) def _check_calls(self, io, expected): - self.assertEqual( - tuple(c[1] for c in io.write.mock_calls), - tuple(expected) - ) + assert tuple(c[1] for c in io.write.mock_calls) == tuple(expected) def test_continuous_paging(self): """ @@ -112,22 +110,23 @@ def test_continuous_paging(self): io = Mock() for version in [version for version in ProtocolVersion.SUPPORTED_VERSIONS if not ProtocolVersion.has_continuous_paging_support(version)]: - self.assertRaises(UnsupportedOperation, message.send_body, io, version) + with pytest.raises(UnsupportedOperation): + message.send_body(io, version) io.reset_mock() message.send_body(io, ProtocolVersion.DSE_V1) # continuous paging adds two write calls to the buffer - self.assertEqual(len(io.write.mock_calls), 6) + assert len(io.write.mock_calls) == 6 # Check that the appropriate flag is set to True - self.assertEqual(uint32_unpack(io.write.mock_calls[3][1][0]) & _WITH_SERIAL_CONSISTENCY_FLAG, 0) - self.assertEqual(uint32_unpack(io.write.mock_calls[3][1][0]) & _PAGE_SIZE_FLAG, 0) - self.assertEqual(uint32_unpack(io.write.mock_calls[3][1][0]) & _WITH_PAGING_STATE_FLAG, 0) - self.assertEqual(uint32_unpack(io.write.mock_calls[3][1][0]) & _PAGING_OPTIONS_FLAG, _PAGING_OPTIONS_FLAG) + assert uint32_unpack(io.write.mock_calls[3][1][0]) & _WITH_SERIAL_CONSISTENCY_FLAG == 0 + assert uint32_unpack(io.write.mock_calls[3][1][0]) & _PAGE_SIZE_FLAG == 0 + assert uint32_unpack(io.write.mock_calls[3][1][0]) & _WITH_PAGING_STATE_FLAG == 0 + assert uint32_unpack(io.write.mock_calls[3][1][0]) & _PAGING_OPTIONS_FLAG == _PAGING_OPTIONS_FLAG # Test max_pages and max_pages_per_second are correctly written - self.assertEqual(uint32_unpack(io.write.mock_calls[4][1][0]), max_pages) - self.assertEqual(uint32_unpack(io.write.mock_calls[5][1][0]), max_pages_per_second) + assert uint32_unpack(io.write.mock_calls[4][1][0]) == max_pages + assert 
uint32_unpack(io.write.mock_calls[5][1][0]) == max_pages_per_second def test_prepare_flag(self): """ @@ -144,9 +143,9 @@ def test_prepare_flag(self): for version in ProtocolVersion.SUPPORTED_VERSIONS: message.send_body(io, version) if ProtocolVersion.uses_prepare_flags(version): - self.assertEqual(len(io.write.mock_calls), 3) + assert len(io.write.mock_calls) == 3 else: - self.assertEqual(len(io.write.mock_calls), 2) + assert len(io.write.mock_calls) == 2 io.reset_mock() def test_prepare_flag_with_keyspace(self): @@ -164,7 +163,7 @@ def test_prepare_flag_with_keyspace(self): (b'ks',), ]) else: - with self.assertRaises(UnsupportedOperation): + with pytest.raises(UnsupportedOperation): message.send_body(io, version) io.reset_mock() @@ -172,7 +171,7 @@ def test_keyspace_flag_raises_before_v5(self): keyspace_message = QueryMessage('a', consistency_level=3, keyspace='ks') io = Mock(name='io') - with self.assertRaisesRegex(UnsupportedOperation, 'Keyspaces.*set'): + with pytest.raises(UnsupportedOperation, match='Keyspaces.*set'): keyspace_message.send_body(io, protocol_version=4) io.assert_not_called() diff --git a/tests/unit/test_protocol_features.py b/tests/unit/test_protocol_features.py index bcf874f68f..89b568ea68 100644 --- a/tests/unit/test_protocol_features.py +++ b/tests/unit/test_protocol_features.py @@ -22,6 +22,6 @@ class OptionsHolder(object): protocol_features = ProtocolFeatures.parse_from_supported(OptionsHolder().options) - self.assertEqual(protocol_features.rate_limit_error, 123) - self.assertEqual(protocol_features.shard_id, 0) - self.assertEqual(protocol_features.sharding_info, None) + assert protocol_features.rate_limit_error == 123 + assert protocol_features.shard_id == 0 + assert protocol_features.sharding_info is None diff --git a/tests/unit/test_query.py b/tests/unit/test_query.py index 8a3f00fa9d..29c800b99c 100644 --- a/tests/unit/test_query.py +++ b/tests/unit/test_query.py @@ -30,26 +30,26 @@ def test_clear(self): batch = BatchStatement() 
batch.add(ss) - self.assertTrue(batch._statements_and_parameters) - self.assertEqual(batch.keyspace, keyspace) - self.assertEqual(batch.routing_key, routing_key) - self.assertEqual(batch.custom_payload, custom_payload) + assert batch._statements_and_parameters + assert batch.keyspace == keyspace + assert batch.routing_key == routing_key + assert batch.custom_payload == custom_payload batch.clear() - self.assertFalse(batch._statements_and_parameters) - self.assertIsNone(batch.keyspace) - self.assertIsNone(batch.routing_key) - self.assertFalse(batch.custom_payload) + assert not batch._statements_and_parameters + assert batch.keyspace is None + assert batch.routing_key is None + assert not batch.custom_payload batch.add(ss) def test_clear_empty(self): batch = BatchStatement() batch.clear() - self.assertFalse(batch._statements_and_parameters) - self.assertIsNone(batch.keyspace) - self.assertIsNone(batch.routing_key) - self.assertFalse(batch.custom_payload) + assert not batch._statements_and_parameters + assert batch.keyspace is None + assert batch.routing_key is None + assert not batch.custom_payload batch.add('something') @@ -60,11 +60,11 @@ def test_add_all(self): batch.add_all(statements, parameters) bound_statements = [t[1] for t in batch._statements_and_parameters] str_parameters = [str(i) for i in range(10)] - self.assertEqual(bound_statements, str_parameters) + assert bound_statements == str_parameters def test_len(self): for n in 0, 10, 100: batch = BatchStatement() batch.add_all(statements=['%s'] * n, parameters=[(i,) for i in range(n)]) - self.assertEqual(len(batch), n) + assert len(batch) == n diff --git a/tests/unit/test_response_future.py b/tests/unit/test_response_future.py index 8226cea440..bcca28ac73 100644 --- a/tests/unit/test_response_future.py +++ b/tests/unit/test_response_future.py @@ -31,6 +31,8 @@ from cassandra.policies import RetryPolicy, ExponentialBackoffRetryPolicy from cassandra.pool import NoConnectionsAvailable from cassandra.query 
import SimpleStatement +from tests.util import assertEqual, assertIsInstance +import pytest class ResponseFutureTests(unittest.TestCase): @@ -81,7 +83,7 @@ def test_result_message(self): expected_result = (object(), object()) rf._set_result(None, None, None, self.make_mock_response(expected_result[0], expected_result[1])) result = rf.result()[0] - self.assertEqual(result, expected_result) + assert result == expected_result def test_unknown_result_class(self): session = self.make_session() @@ -92,7 +94,8 @@ def test_unknown_result_class(self): rf = self.make_response_future(session) rf.send_request() rf._set_result(None, None, None, object()) - self.assertRaises(ConnectionException, rf.result) + with pytest.raises(ConnectionException): + rf.result() def test_set_keyspace_result(self): session = self.make_session() @@ -104,7 +107,7 @@ def test_set_keyspace_result(self): results="keyspace1") rf._set_result(None, None, None, result) rf._set_keyspace_completed({}) - self.assertFalse(rf.result()) + assert not rf.result() def test_schema_change_result(self): session = self.make_session() @@ -126,7 +129,7 @@ def test_other_result_message_kind(self): rf.send_request() result = Mock(spec=ResultMessage, kind=999, results=[1, 2, 3]) rf._set_result(None, None, None, result) - self.assertEqual(rf.result()[0], result) + assert rf.result()[0] == result def test_heartbeat_defunct_deadlock(self): """ @@ -159,7 +162,8 @@ def test_heartbeat_defunct_deadlock(self): # Simulate ResponseFuture timing out rf._on_timeout() - self.assertRaisesRegex(OperationTimedOut, "Connection defunct by heartbeat", rf.result) + with pytest.raises(OperationTimedOut, match="Connection defunct by heartbeat"): + rf.result() def test_read_timeout_error_message(self): session = self.make_session() @@ -173,7 +177,8 @@ def test_read_timeout_error_message(self): "received_responses":1, "consistency": 1}) rf._set_result(None, None, None, result) - self.assertRaises(Exception, rf.result) + with 
pytest.raises(Exception): + rf.result() def test_write_timeout_error_message(self): session = self.make_session() @@ -186,7 +191,8 @@ def test_write_timeout_error_message(self): result = Mock(spec=WriteTimeoutErrorMessage, info={"write_type": 1, "required_responses":2, "received_responses":1, "consistency": 1}) rf._set_result(None, None, None, result) - self.assertRaises(Exception, rf.result) + with pytest.raises(Exception): + rf.result() def test_unavailable_error_message(self): session = self.make_session() @@ -201,7 +207,8 @@ def test_unavailable_error_message(self): result = Mock(spec=UnavailableErrorMessage, info={"required_replicas":2, "alive_replicas": 1, "consistency": 1}) rf._set_result(None, None, None, result) - self.assertRaises(Exception, rf.result) + with pytest.raises(Exception): + rf.result() def test_request_error_with_prepare_message(self): session = self.make_session() @@ -216,14 +223,14 @@ def test_request_error_with_prepare_message(self): result = Mock(spec=OverloadedErrorMessage) result.to_exception.return_value = result rf._set_result(None, None, None, result) - self.assertIsInstance(rf._final_exception, OverloadedErrorMessage) + assert isinstance(rf._final_exception, OverloadedErrorMessage) rf = ResponseFuture(session, message, query, 1, retry_policy=retry_policy) rf._query_retries = 1 rf.send_request() result = Mock(spec=ConnectionException) rf._set_result(None, None, None, result) - self.assertIsInstance(rf._final_exception, ConnectionException) + assert isinstance(rf._final_exception, ConnectionException) def test_retry_policy_says_ignore(self): session = self.make_session() @@ -237,7 +244,7 @@ def test_retry_policy_says_ignore(self): result = Mock(spec=UnavailableErrorMessage, info={}) rf._set_result(None, None, None, result) - self.assertFalse(rf.result()) + assert not rf.result() def test_retry_policy_says_retry(self): session = self.make_session() @@ -264,7 +271,7 @@ def test_retry_policy_says_retry(self): rf._set_result(host, None, 
None, result) rf.session.cluster.scheduler.schedule.assert_called_once_with(ANY, rf._retry_task, True, host) - self.assertEqual(1, rf._query_retries) + assert 1 == rf._query_retries connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 2) @@ -292,7 +299,7 @@ def test_retry_with_different_host(self): rf.session._pools.get.assert_called_once_with('ip1') pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY, keyspace=ANY, table=ANY) connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) - self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level) + assert ConsistencyLevel.QUORUM == rf.message.consistency_level result = Mock(spec=OverloadedErrorMessage, info={}) host = Mock() @@ -300,7 +307,7 @@ def test_retry_with_different_host(self): rf.session.cluster.scheduler.schedule.assert_called_once_with(ANY, rf._retry_task, False, host) # query_retries does get incremented for Overloaded/Bootstrapping errors (since 3.18) - self.assertEqual(1, rf._query_retries) + assert 1 == rf._query_retries connection = Mock(spec=Connection) pool.borrow_connection.return_value = (connection, 2) @@ -313,7 +320,7 @@ def test_retry_with_different_host(self): connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[]) # the consistency level should be the same - self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level) + assert ConsistencyLevel.QUORUM == rf.message.consistency_level def test_all_retries_fail(self): session = self.make_session() @@ -344,7 +351,8 @@ def test_all_retries_fail(self): rf.session.cluster.scheduler.schedule.assert_called_with(ANY, rf._retry_task, False, host) rf._retry_task(False, host) - self.assertRaises(NoHostAvailable, rf.result) + with pytest.raises(NoHostAvailable): + 
rf.result() def test_exponential_retry_policy_fail(self): session = self.make_session() @@ -376,7 +384,8 @@ def test_all_pools_shutdown(self): rf = ResponseFuture(session, Mock(), Mock(), 1) rf.send_request() - self.assertRaises(NoHostAvailable, rf.result) + with pytest.raises(NoHostAvailable): + rf.result() def test_first_pool_shutdown(self): session = self.make_basic_session() @@ -395,7 +404,7 @@ def test_first_pool_shutdown(self): rf._set_result(None, None, None, self.make_mock_response(expected_result[0], expected_result[1])) result = rf.result()[0] - self.assertEqual(result, expected_result) + assert result == expected_result def test_timeout_getting_connection_from_pool(self): session = self.make_basic_session() @@ -418,10 +427,10 @@ def test_timeout_getting_connection_from_pool(self): expected_result = (object(), object()) rf._set_result(None, None, None, self.make_mock_response(expected_result[0], expected_result[1])) - self.assertEqual(rf.result()[0], expected_result) + assert rf.result()[0] == expected_result # make sure the exception is recorded correctly - self.assertEqual(rf._errors, {'ip1': exc}) + assert rf._errors == {'ip1': exc} def test_callback(self): session = self.make_session() @@ -437,12 +446,12 @@ def test_callback(self): rf._set_result(None, None, None, self.make_mock_response(expected_result[0], expected_result[1])) result = rf.result()[0] - self.assertEqual(result, expected_result) + assert result == expected_result callback.assert_called_once_with([expected_result], arg, **kwargs) # this should get called immediately now that the result is set - rf.add_callback(self.assertEqual, [expected_result]) + rf.add_callback(assertEqual, [expected_result]) def test_errback(self): session = self.make_session() @@ -457,16 +466,17 @@ def test_errback(self): rf._query_retries = 1 rf.send_request() - rf.add_errback(self.assertIsInstance, Exception) + rf.add_errback(assertIsInstance, Exception) result = Mock(spec=UnavailableErrorMessage, 
info={"required_replicas":2, "alive_replicas": 1, "consistency": 1}) result.to_exception.return_value = Exception() rf._set_result(None, None, None, result) - self.assertRaises(Exception, rf.result) + with pytest.raises(Exception): + rf.result() # this should get called immediately now that the error is set - rf.add_errback(self.assertIsInstance, Exception) + rf.add_errback(assertIsInstance, Exception) def test_multiple_callbacks(self): session = self.make_session() @@ -487,7 +497,7 @@ def test_multiple_callbacks(self): rf._set_result(None, None, None, self.make_mock_response(expected_result[0], expected_result[1])) result = rf.result()[0] - self.assertEqual(result, expected_result) + assert result == expected_result callback.assert_called_once_with([expected_result], arg, **kwargs) callback2.assert_called_once_with([expected_result], arg2, **kwargs2) @@ -521,7 +531,8 @@ def test_multiple_errbacks(self): result.to_exception.return_value = expected_exception rf._set_result(None, None, None, result) rf._event.set() - self.assertRaises(Exception, rf.result) + with pytest.raises(Exception): + rf.result() callback.assert_called_once_with(expected_exception, arg, **kwargs) callback2.assert_called_once_with(expected_exception, arg2, **kwargs2) @@ -537,14 +548,15 @@ def test_add_callbacks(self): rf.send_request() rf.add_callbacks( - callback=self.assertEqual, callback_args=([{'col': 'val'}],), - errback=self.assertIsInstance, errback_args=(Exception,)) + callback=assertEqual, callback_args=([{'col': 'val'}],), + errback=assertIsInstance, errback_args=(Exception,)) result = Mock(spec=UnavailableErrorMessage, info={"required_replicas":2, "alive_replicas": 1, "consistency": 1}) result.to_exception.return_value = Exception() rf._set_result(None, None, None, result) - self.assertRaises(Exception, rf.result) + with pytest.raises(Exception): + rf.result() # test callback rf = ResponseFuture(session, message, query, 1) @@ -556,10 +568,10 @@ def test_add_callbacks(self): kwargs = 
{'one': 1, 'two': 2} rf.add_callbacks( callback=callback, callback_args=(arg,), callback_kwargs=kwargs, - errback=self.assertIsInstance, errback_args=(Exception,)) + errback=assertIsInstance, errback_args=(Exception,)) rf._set_result(None, None, None, self.make_mock_response(expected_result[0], expected_result[1])) - self.assertEqual(rf.result()[0], expected_result) + assert rf.result()[0] == expected_result callback.assert_called_once_with([expected_result], arg, **kwargs) @@ -582,11 +594,11 @@ def test_prepared_query_not_found(self): result = Mock(spec=PreparedQueryNotFound, info='a' * 16) rf._set_result(None, None, None, result) - self.assertTrue(session.submit.call_args) + assert session.submit.call_args args, kwargs = session.submit.call_args - self.assertEqual(rf._reprepare, args[-5]) - self.assertIsInstance(args[-4], PrepareMessage) - self.assertEqual(args[-4].query, "SELECT * FROM foobar") + assert rf._reprepare == args[-5] + assert isinstance(args[-4], PrepareMessage) + assert args[-4].query == "SELECT * FROM foobar" def test_prepared_query_not_found_bad_keyspace(self): session = self.make_session() @@ -606,7 +618,8 @@ def test_prepared_query_not_found_bad_keyspace(self): result = Mock(spec=PreparedQueryNotFound, info='a' * 16) rf._set_result(None, None, None, result) - self.assertRaises(ValueError, rf.result) + with pytest.raises(ValueError): + rf.result() def test_repeat_orig_query_after_succesful_reprepare(self): query_id = b'abc123' # Just a random binary string so we don't hit id mismatch exception @@ -650,7 +663,8 @@ def test_timeout_does_not_release_stream_id(self): rf._on_timeout() pool.return_connection.assert_called_once_with(connection, stream_was_orphaned=True) - self.assertRaisesRegex(OperationTimedOut, "Client request timeout", rf.result) + with pytest.raises(OperationTimedOut, match="Client request timeout"): + rf.result() assert len(connection.request_ids) == 0, \ "Request IDs should be empty but it's not: {}".format(connection.request_ids) 
diff --git a/tests/unit/test_resultset.py b/tests/unit/test_resultset.py index 7ff6352394..80e9c21ff9 100644 --- a/tests/unit/test_resultset.py +++ b/tests/unit/test_resultset.py @@ -18,6 +18,9 @@ from cassandra.cluster import ResultSet from cassandra.query import named_tuple_factory, dict_factory, tuple_factory +from tests.util import assertListEqual +import pytest + class ResultSetTests(unittest.TestCase): @@ -25,7 +28,7 @@ def test_iter_non_paged(self): expected = list(range(10)) rs = ResultSet(Mock(has_more_pages=False), expected) itr = iter(rs) - self.assertListEqual(list(itr), expected) + assertListEqual(list(itr), expected) def test_iter_paged(self): expected = list(range(10)) @@ -35,7 +38,7 @@ def test_iter_paged(self): itr = iter(rs) # this is brittle, depends on internal impl details. Would like to find a better way type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, False)) # after init to avoid side effects being consumed by init - self.assertListEqual(list(itr), expected) + assertListEqual(list(itr), expected) def test_iter_paged_with_empty_pages(self): expected = list(range(10)) @@ -48,15 +51,15 @@ def test_iter_paged_with_empty_pages(self): ] rs = ResultSet(response_future, []) itr = iter(rs) - self.assertListEqual(list(itr), expected) + assertListEqual(list(itr), expected) def test_list_non_paged(self): # list access on RS for backwards-compatibility expected = list(range(10)) rs = ResultSet(Mock(has_more_pages=False), expected) for i in range(10): - self.assertEqual(rs[i], expected[i]) - self.assertEqual(list(rs), expected) + assert rs[i] == expected[i] + assert list(rs) == expected def test_list_paged(self): # list access on RS for backwards-compatibility @@ -66,16 +69,16 @@ def test_list_paged(self): rs = ResultSet(response_future, expected[:5]) # this is brittle, depends on internal impl details. 
Would like to find a better way type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, True, False)) # First two True are consumed on check entering list mode - self.assertEqual(rs[9], expected[9]) - self.assertEqual(list(rs), expected) + assert rs[9] == expected[9] + assert list(rs) == expected def test_has_more_pages(self): response_future = Mock() response_future.has_more_pages.side_effect = PropertyMock(side_effect=(True, False)) rs = ResultSet(response_future, []) type(response_future).has_more_pages = PropertyMock(side_effect=(True, False)) # after init to avoid side effects being consumed by init - self.assertTrue(rs.has_more_pages) - self.assertFalse(rs.has_more_pages) + assert rs.has_more_pages + assert not rs.has_more_pages def test_iterate_then_index(self): # RuntimeError if indexing with no pages @@ -83,15 +86,15 @@ def test_iterate_then_index(self): rs = ResultSet(Mock(has_more_pages=False), expected) itr = iter(rs) # before consuming - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): rs[0] list(itr) # after consuming - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): rs[0] - self.assertFalse(rs) - self.assertFalse(list(rs)) + assert not rs + assert not list(rs) # RuntimeError if indexing during or after pages response_future = Mock(has_more_pages=True, _continuous_paging_session=None) @@ -100,17 +103,17 @@ def test_iterate_then_index(self): type(response_future).has_more_pages = PropertyMock(side_effect=(True, False)) itr = iter(rs) # before consuming - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): rs[0] for row in itr: # while consuming - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): rs[0] # after consuming - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): rs[0] - self.assertFalse(rs) - self.assertFalse(list(rs)) + assert not rs + assert not list(rs) def test_index_list_mode(self): # no pages @@ 
-118,13 +121,13 @@ def test_index_list_mode(self): rs = ResultSet(Mock(has_more_pages=False), expected) # index access before iteration causes list to be materialized - self.assertEqual(rs[0], expected[0]) + assert rs[0] == expected[0] # resusable iteration - self.assertListEqual(list(rs), expected) - self.assertListEqual(list(rs), expected) + assertListEqual(list(rs), expected) + assertListEqual(list(rs), expected) - self.assertTrue(rs) + assert rs # pages response_future = Mock(has_more_pages=True, _continuous_paging_session=None) @@ -133,13 +136,13 @@ def test_index_list_mode(self): # this is brittle, depends on internal impl details. Would like to find a better way type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, True, False)) # First two True are consumed on check entering list mode # index access before iteration causes list to be materialized - self.assertEqual(rs[0], expected[0]) - self.assertEqual(rs[9], expected[9]) + assert rs[0] == expected[0] + assert rs[9] == expected[9] # resusable iteration - self.assertListEqual(list(rs), expected) - self.assertListEqual(list(rs), expected) + assertListEqual(list(rs), expected) + assertListEqual(list(rs), expected) - self.assertTrue(rs) + assert rs def test_eq(self): # no pages @@ -147,12 +150,12 @@ def test_eq(self): rs = ResultSet(Mock(has_more_pages=False), expected) # eq before iteration causes list to be materialized - self.assertEqual(rs, expected) + assert rs == expected # results can be iterated or indexed once we're materialized - self.assertListEqual(list(rs), expected) - self.assertEqual(rs[9], expected[9]) - self.assertTrue(rs) + assertListEqual(list(rs), expected) + assert rs[9] == expected[9] + assert rs # pages response_future = Mock(has_more_pages=True, _continuous_paging_session=None) @@ -160,56 +163,56 @@ def test_eq(self): rs = ResultSet(response_future, expected[:5]) type(response_future).has_more_pages = PropertyMock(side_effect=(True, True, True, False)) # eq before 
iteration causes list to be materialized - self.assertEqual(rs, expected) + assert rs == expected # results can be iterated or indexed once we're materialized - self.assertListEqual(list(rs), expected) - self.assertEqual(rs[9], expected[9]) - self.assertTrue(rs) + assertListEqual(list(rs), expected) + assert rs[9] == expected[9] + assert rs def test_bool(self): - self.assertFalse(ResultSet(Mock(has_more_pages=False), [])) - self.assertTrue(ResultSet(Mock(has_more_pages=False), [1])) + assert not ResultSet(Mock(has_more_pages=False), []) + assert ResultSet(Mock(has_more_pages=False), [1]) def test_was_applied(self): # unknown row factory raises - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): ResultSet(Mock(), []).was_applied response_future = Mock(row_factory=named_tuple_factory) # no row - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): ResultSet(response_future, []).was_applied # too many rows - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): ResultSet(response_future, [tuple(), tuple()]).was_applied # various internal row factories for row_factory in (named_tuple_factory, tuple_factory): for applied in (True, False): rs = ResultSet(Mock(row_factory=row_factory), [(applied,)]) - self.assertEqual(rs.was_applied, applied) + assert rs.was_applied == applied row_factory = dict_factory for applied in (True, False): rs = ResultSet(Mock(row_factory=row_factory), [{'[applied]': applied}]) - self.assertEqual(rs.was_applied, applied) + assert rs.was_applied == applied def test_one(self): # no pages first, second = Mock(), Mock() rs = ResultSet(Mock(has_more_pages=False), [first, second]) - self.assertEqual(rs.one(), first) + assert rs.one() == first def test_all(self): first, second = Mock(), Mock() rs1 = ResultSet(Mock(has_more_pages=False), [first, second]) rs2 = ResultSet(Mock(has_more_pages=False), [first, second]) - self.assertEqual(rs1.all(), list(rs2)) + assert rs1.all() == list(rs2) 
@patch('cassandra.cluster.warn') def test_indexing_deprecation(self, mocked_warn): @@ -217,9 +220,8 @@ def test_indexing_deprecation(self, mocked_warn): # pre-Py3.0 for some reason first, second = Mock(), Mock() rs = ResultSet(Mock(has_more_pages=False), [first, second]) - self.assertEqual(rs[0], first) - self.assertEqual(len(mocked_warn.mock_calls), 1) + assert rs[0] == first + assert len(mocked_warn.mock_calls) == 1 index_warning_args = tuple(mocked_warn.mock_calls[0])[1] - self.assertIn('indexing support will be removed in 4.0', - str(index_warning_args[0])) - self.assertIs(index_warning_args[1], DeprecationWarning) + assert 'indexing support will be removed in 4.0' in str(index_warning_args[0]) + assert index_warning_args[1] is DeprecationWarning diff --git a/tests/unit/test_row_factories.py b/tests/unit/test_row_factories.py index 70691ad8fd..7787f1d271 100644 --- a/tests/unit/test_row_factories.py +++ b/tests/unit/test_row_factories.py @@ -61,13 +61,13 @@ def test_creation_warning_on_long_column_list(self): with warnings.catch_warnings(record=True) as w: rows = named_tuple_factory(self.long_colnames, self.long_rows) - self.assertEqual(len(w), 1) + assert len(w) == 1 warning = w[0] - self.assertIn('pseudo_namedtuple_factory', str(warning)) - self.assertIn('3.7', str(warning)) + assert 'pseudo_namedtuple_factory' in str(warning) + assert '3.7' in str(warning) for r in rows: - self.assertEqual(r.col0, self.long_rows[0][0]) + assert r.col0 == self.long_rows[0][0] def test_creation_no_warning_on_short_column_list(self): """ @@ -81,7 +81,7 @@ def test_creation_no_warning_on_short_column_list(self): """ with warnings.catch_warnings(record=True) as w: rows = named_tuple_factory(self.short_colnames, self.short_rows) - self.assertEqual(len(w), 0) + assert len(w) == 0 # check that this is a real namedtuple - self.assertTrue(hasattr(rows[0], '_fields')) - self.assertIsInstance(rows[0], tuple) + assert hasattr(rows[0], '_fields') + assert isinstance(rows[0], tuple) diff 
--git a/tests/unit/test_segment.py b/tests/unit/test_segment.py index 0d0f146c16..bfc273db05 100644 --- a/tests/unit/test_segment.py +++ b/tests/unit/test_segment.py @@ -19,6 +19,7 @@ from cassandra import DriverException from cassandra.segment import Segment, CrcException from cassandra.connection import segment_codec_no_compression, segment_codec_lz4 +import pytest def to_bits(b): @@ -50,10 +51,8 @@ def _header_to_bits(data): def test_encode_uncompressed_header(self): buffer = BytesIO() segment_codec_no_compression.encode_header(buffer, len(self.small_msg), -1, True) - self.assertEqual(buffer.tell(), 6) - self.assertEqual( - self._header_to_bits(buffer.getvalue()), - "00000000000110010" + "1" + "000000") + assert buffer.tell() == 6 + assert self._header_to_bits(buffer.getvalue()) == "00000000000110010" + "1" + "000000" @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed') def test_encode_compressed_header(self): @@ -61,45 +60,37 @@ def test_encode_compressed_header(self): compressed_length = len(segment_codec_lz4.compress(self.small_msg)) segment_codec_lz4.encode_header(buffer, compressed_length, len(self.small_msg), True) - self.assertEqual(buffer.tell(), 8) - self.assertEqual( - self._header_to_bits(buffer.getvalue()), - "{:017b}".format(compressed_length) + "00000000000110010" + "1" + "00000") + assert buffer.tell() == 8 + assert self._header_to_bits(buffer.getvalue()) == "{:017b}".format(compressed_length) + "00000000000110010" + "1" + "00000" def test_encode_uncompressed_header_with_max_payload(self): buffer = BytesIO() segment_codec_no_compression.encode_header(buffer, len(self.max_msg), -1, True) - self.assertEqual(buffer.tell(), 6) - self.assertEqual( - self._header_to_bits(buffer.getvalue()), - "11111111111111111" + "1" + "000000") + assert buffer.tell() == 6 + assert self._header_to_bits(buffer.getvalue()) == "11111111111111111" + "1" + "000000" def test_encode_header_fails_if_payload_too_big(self): buffer = BytesIO() for codec in [c for c in 
[segment_codec_no_compression, segment_codec_lz4] if c is not None]: - with self.assertRaises(DriverException): + with pytest.raises(DriverException): codec.encode_header(buffer, len(self.large_msg), -1, False) def test_encode_uncompressed_header_not_self_contained_msg(self): buffer = BytesIO() # simulate the first chunk with the max size segment_codec_no_compression.encode_header(buffer, len(self.max_msg), -1, False) - self.assertEqual(buffer.tell(), 6) - self.assertEqual( - self._header_to_bits(buffer.getvalue()), - ("11111111111111111" - "0" # not self contained - "000000")) + assert buffer.tell() == 6 + assert self._header_to_bits(buffer.getvalue()) == ("11111111111111111" + "0" # not self contained + "000000") @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed') def test_encode_compressed_header_with_max_payload(self): buffer = BytesIO() compressed_length = len(segment_codec_lz4.compress(self.max_msg)) segment_codec_lz4.encode_header(buffer, compressed_length, len(self.max_msg), True) - self.assertEqual(buffer.tell(), 8) - self.assertEqual( - self._header_to_bits(buffer.getvalue()), - "{:017b}".format(compressed_length) + "11111111111111111" + "1" + "00000") + assert buffer.tell() == 8 + assert self._header_to_bits(buffer.getvalue()) == "{:017b}".format(compressed_length) + "11111111111111111" + "1" + "00000" @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed') def test_encode_compressed_header_not_self_contained_msg(self): @@ -107,22 +98,20 @@ def test_encode_compressed_header_not_self_contained_msg(self): # simulate the first chunk with the max size compressed_length = len(segment_codec_lz4.compress(self.max_msg)) segment_codec_lz4.encode_header(buffer, compressed_length, len(self.max_msg), False) - self.assertEqual(buffer.tell(), 8) - self.assertEqual( - self._header_to_bits(buffer.getvalue()), - ("{:017b}".format(compressed_length) + - "11111111111111111" - "0" # not self contained - "00000")) + assert buffer.tell() == 8 + assert 
self._header_to_bits(buffer.getvalue()) == ("{:017b}".format(compressed_length) + + "11111111111111111" + "0" # not self contained + "00000") def test_decode_uncompressed_header(self): buffer = BytesIO() segment_codec_no_compression.encode_header(buffer, len(self.small_msg), -1, True) buffer.seek(0) header = segment_codec_no_compression.decode_header(buffer) - self.assertEqual(header.uncompressed_payload_length, -1) - self.assertEqual(header.payload_length, len(self.small_msg)) - self.assertEqual(header.is_self_contained, True) + assert header.uncompressed_payload_length == -1 + assert header.payload_length == len(self.small_msg) + assert header.is_self_contained == True @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed') def test_decode_compressed_header(self): @@ -131,9 +120,9 @@ def test_decode_compressed_header(self): segment_codec_lz4.encode_header(buffer, compressed_length, len(self.small_msg), True) buffer.seek(0) header = segment_codec_lz4.decode_header(buffer) - self.assertEqual(header.uncompressed_payload_length, len(self.small_msg)) - self.assertEqual(header.payload_length, compressed_length) - self.assertEqual(header.is_self_contained, True) + assert header.uncompressed_payload_length == len(self.small_msg) + assert header.payload_length == compressed_length + assert header.is_self_contained == True def test_decode_header_fails_if_corrupted(self): buffer = BytesIO() @@ -143,7 +132,7 @@ def test_decode_header_fails_if_corrupted(self): buffer.write(b'0') buffer.seek(0) - with self.assertRaises(CrcException): + with pytest.raises(CrcException): segment_codec_no_compression.decode_header(buffer) def test_decode_uncompressed_self_contained_segment(self): @@ -154,10 +143,10 @@ def test_decode_uncompressed_self_contained_segment(self): header = segment_codec_no_compression.decode_header(buffer) segment = segment_codec_no_compression.decode(buffer, header) - self.assertEqual(header.is_self_contained, True) - 
self.assertEqual(header.uncompressed_payload_length, -1) - self.assertEqual(header.payload_length, len(self.small_msg)) - self.assertEqual(segment.payload, self.small_msg) + assert header.is_self_contained == True + assert header.uncompressed_payload_length == -1 + assert header.payload_length == len(self.small_msg) + assert segment.payload == self.small_msg @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed') def test_decode_compressed_self_contained_segment(self): @@ -168,10 +157,10 @@ def test_decode_compressed_self_contained_segment(self): header = segment_codec_lz4.decode_header(buffer) segment = segment_codec_lz4.decode(buffer, header) - self.assertEqual(header.is_self_contained, True) - self.assertEqual(header.uncompressed_payload_length, len(self.small_msg)) - self.assertGreater(header.uncompressed_payload_length, header.payload_length) - self.assertEqual(segment.payload, self.small_msg) + assert header.is_self_contained == True + assert header.uncompressed_payload_length == len(self.small_msg) + assert header.uncompressed_payload_length > header.payload_length + assert segment.payload == self.small_msg def test_decode_multi_segments(self): buffer = BytesIO() @@ -186,9 +175,9 @@ def test_decode_multi_segments(self): headers.append(segment_codec_no_compression.decode_header(buffer)) segments.append(segment_codec_no_compression.decode(buffer, headers[1])) - self.assertTrue(all([h.is_self_contained is False for h in headers])) + assert all([h.is_self_contained is False for h in headers]) decoded_msg = segments[0].payload + segments[1].payload - self.assertEqual(decoded_msg, self.large_msg) + assert decoded_msg == self.large_msg @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed') def test_decode_fails_if_corrupted(self): @@ -198,7 +187,7 @@ def test_decode_fails_if_corrupted(self): buffer.write(b'0') buffer.seek(0) header = segment_codec_lz4.decode_header(buffer) - with self.assertRaises(CrcException): + with pytest.raises(CrcException): 
segment_codec_lz4.decode(buffer, header) @unittest.skipUnless(segment_codec_lz4, ' lz4 not installed') @@ -208,6 +197,6 @@ def test_decode_tiny_msg_not_compressed(self): buffer.seek(0) header = segment_codec_lz4.decode_header(buffer) segment = segment_codec_lz4.decode(buffer, header) - self.assertEqual(header.uncompressed_payload_length, 0) - self.assertEqual(header.payload_length, 1) - self.assertEqual(segment.payload, b'b') + assert header.uncompressed_payload_length == 0 + assert header.payload_length == 1 + assert segment.payload == b'b' diff --git a/tests/unit/test_shard_aware.py b/tests/unit/test_shard_aware.py index fe7b95edba..8b34eb2578 100644 --- a/tests/unit/test_shard_aware.py +++ b/tests/unit/test_shard_aware.py @@ -46,12 +46,12 @@ class OptionsHolder(object): } shard_id, shard_info = ProtocolFeatures.parse_sharding_info(OptionsHolder().options) - self.assertEqual(shard_id, 1) - self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"a").value), 4) - self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"b").value), 6) - self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"c").value), 6) - self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"e").value), 4) - self.assertEqual(shard_info.shard_id_from_token(Murmur3Token.from_key(b"100000").value), 2) + assert shard_id == 1 + assert shard_info.shard_id_from_token(Murmur3Token.from_key(b"a").value) == 4 + assert shard_info.shard_id_from_token(Murmur3Token.from_key(b"b").value) == 6 + assert shard_info.shard_id_from_token(Murmur3Token.from_key(b"c").value) == 6 + assert shard_info.shard_id_from_token(Murmur3Token.from_key(b"e").value) == 4 + assert shard_info.shard_id_from_token(Murmur3Token.from_key(b"100000").value) == 2 def test_advanced_shard_aware_port(self): """ diff --git a/tests/unit/test_sortedset.py b/tests/unit/test_sortedset.py index 49c3658df8..071907d53e 100644 --- a/tests/unit/test_sortedset.py +++ 
b/tests/unit/test_sortedset.py @@ -13,10 +13,13 @@ # limitations under the License. import unittest +import pytest from cassandra.util import sortedset from cassandra.cqltypes import EMPTY +from tests.util import assertListEqual + from datetime import datetime from itertools import permutations @@ -25,11 +28,11 @@ def test_init(self): input = [5, 4, 3, 2, 1, 1, 1] expected = sorted(set(input)) ss = sortedset(input) - self.assertEqual(len(ss), len(expected)) - self.assertEqual(list(ss), expected) + assert len(ss) == len(expected) + assert list(ss) == expected def test_repr(self): - self.assertEqual(repr(sortedset([1, 2, 3, 4])), "SortedSet([1, 2, 3, 4])") + assert repr(sortedset([1, 2, 3, 4])) == "SortedSet([1, 2, 3, 4])" def test_contains(self): input = [5, 4, 3, 2, 1, 1, 1] @@ -37,24 +40,24 @@ def test_contains(self): ss = sortedset(input) for i in expected: - self.assertTrue(i in ss) - self.assertFalse(i not in ss) + assert i in ss + assert not i not in ss hi = max(expected)+1 lo = min(expected)-1 - self.assertFalse(hi in ss) - self.assertFalse(lo in ss) + assert not hi in ss + assert not lo in ss def test_mutable_contents(self): ba = bytearray(b'some data here') ss = sortedset([ba, ba]) - self.assertEqual(list(ss), [ba]) + assert list(ss) == [ba] def test_clear(self): ss = sortedset([1, 2, 3]) ss.clear() - self.assertEqual(len(ss), 0) + assert len(ss) == 0 def test_equal(self): s1 = set([1]) @@ -62,15 +65,15 @@ def test_equal(self): ss1 = sortedset(s1) ss12 = sortedset(s12) - self.assertEqual(ss1, s1) - self.assertEqual(ss12, s12) - self.assertEqual(ss12, s12) - self.assertEqual(ss1.__eq__(None), NotImplemented) - self.assertNotEqual(ss1, ss12) - self.assertNotEqual(ss12, ss1) - self.assertNotEqual(ss1, s12) - self.assertNotEqual(ss12, s1) - self.assertNotEqual(ss1, EMPTY) + assert ss1 == s1 + assert ss12 == s12 + assert ss12 == s12 + assert ss1.__eq__(None) == NotImplemented + assert ss1 != ss12 + assert ss12 != ss1 + assert ss1 != s12 + assert ss12 != s1 + 
assert ss1 != EMPTY def test_copy(self): class comparable(object): @@ -80,9 +83,9 @@ def __lt__(self, other): o = comparable() ss = sortedset([comparable(), o]) ss2 = ss.copy() - self.assertNotEqual(id(ss), id(ss2)) - self.assertTrue(o in ss) - self.assertTrue(o in ss2) + assert id(ss) != id(ss2) + assert o in ss + assert o in ss2 def test_isdisjoint(self): # set, ss @@ -92,25 +95,25 @@ def test_isdisjoint(self): ss13 = sortedset([1, 3]) ss3 = sortedset([3]) # s ss disjoint - self.assertTrue(s2.isdisjoint(ss1)) - self.assertTrue(s2.isdisjoint(ss13)) + assert s2.isdisjoint(ss1) + assert s2.isdisjoint(ss13) # s ss not disjoint - self.assertFalse(s12.isdisjoint(ss1)) - self.assertFalse(s12.isdisjoint(ss13)) + assert not s12.isdisjoint(ss1) + assert not s12.isdisjoint(ss13) # ss s disjoint - self.assertTrue(ss1.isdisjoint(s2)) - self.assertTrue(ss13.isdisjoint(s2)) + assert ss1.isdisjoint(s2) + assert ss13.isdisjoint(s2) # ss s not disjoint - self.assertFalse(ss1.isdisjoint(s12)) - self.assertFalse(ss13.isdisjoint(s12)) + assert not ss1.isdisjoint(s12) + assert not ss13.isdisjoint(s12) # ss ss disjoint - self.assertTrue(ss1.isdisjoint(ss3)) - self.assertTrue(ss3.isdisjoint(ss1)) + assert ss1.isdisjoint(ss3) + assert ss3.isdisjoint(ss1) # ss ss not disjoint - self.assertFalse(ss1.isdisjoint(ss13)) - self.assertFalse(ss13.isdisjoint(ss1)) - self.assertFalse(ss3.isdisjoint(ss13)) - self.assertFalse(ss13.isdisjoint(ss3)) + assert not ss1.isdisjoint(ss13) + assert not ss13.isdisjoint(ss1) + assert not ss3.isdisjoint(ss13) + assert not ss13.isdisjoint(ss3) def test_issubset(self): s12 = set([1, 2]) @@ -118,13 +121,13 @@ def test_issubset(self): ss13 = sortedset([1, 3]) ss3 = sortedset([3]) - self.assertTrue(ss1.issubset(s12)) - self.assertTrue(ss1.issubset(ss13)) + assert ss1.issubset(s12) + assert ss1.issubset(ss13) - self.assertFalse(ss1.issubset(ss3)) - self.assertFalse(ss13.issubset(ss3)) - self.assertFalse(ss13.issubset(ss1)) - self.assertFalse(ss13.issubset(s12)) + 
assert not ss1.issubset(ss3) + assert not ss13.issubset(ss3) + assert not ss13.issubset(ss1) + assert not ss13.issubset(s12) def test_issuperset(self): s12 = set([1, 2]) @@ -132,253 +135,253 @@ def test_issuperset(self): ss13 = sortedset([1, 3]) ss3 = sortedset([3]) - self.assertTrue(s12.issuperset(ss1)) - self.assertTrue(ss13.issuperset(ss3)) - self.assertTrue(ss13.issuperset(ss13)) + assert s12.issuperset(ss1) + assert ss13.issuperset(ss3) + assert ss13.issuperset(ss13) - self.assertFalse(s12.issuperset(ss13)) - self.assertFalse(ss1.issuperset(ss3)) - self.assertFalse(ss1.issuperset(ss13)) + assert not s12.issuperset(ss13) + assert not ss1.issuperset(ss3) + assert not ss1.issuperset(ss13) def test_union(self): s1 = set([1]) ss12 = sortedset([1, 2]) ss23 = sortedset([2, 3]) - self.assertEqual(sortedset().union(s1), sortedset([1])) - self.assertEqual(ss12.union(s1), sortedset([1, 2])) - self.assertEqual(ss12.union(ss23), sortedset([1, 2, 3])) - self.assertEqual(ss23.union(ss12), sortedset([1, 2, 3])) - self.assertEqual(ss23.union(s1), sortedset([1, 2, 3])) + assert sortedset().union(s1) == sortedset([1]) + assert ss12.union(s1) == sortedset([1, 2]) + assert ss12.union(ss23) == sortedset([1, 2, 3]) + assert ss23.union(ss12) == sortedset([1, 2, 3]) + assert ss23.union(s1) == sortedset([1, 2, 3]) def test_intersection(self): s12 = set([1, 2]) ss23 = sortedset([2, 3]) - self.assertEqual(s12.intersection(ss23), set([2])) - self.assertEqual(ss23.intersection(s12), sortedset([2])) - self.assertEqual(ss23.intersection(s12, [2], (2,)), sortedset([2])) - self.assertEqual(ss23.intersection(s12, [900], (2,)), sortedset()) + assert s12.intersection(ss23) == set([2]) + assert ss23.intersection(s12) == sortedset([2]) + assert ss23.intersection(s12, [2], (2,)) == sortedset([2]) + assert ss23.intersection(s12, [900], (2,)) == sortedset() def test_difference(self): s1 = set([1]) ss12 = sortedset([1, 2]) ss23 = sortedset([2, 3]) - self.assertEqual(sortedset().difference(s1), 
sortedset()) - self.assertEqual(ss12.difference(s1), sortedset([2])) - self.assertEqual(ss12.difference(ss23), sortedset([1])) - self.assertEqual(ss23.difference(ss12), sortedset([3])) - self.assertEqual(ss23.difference(s1), sortedset([2, 3])) + assert sortedset().difference(s1) == sortedset() + assert ss12.difference(s1) == sortedset([2]) + assert ss12.difference(ss23) == sortedset([1]) + assert ss23.difference(ss12) == sortedset([3]) + assert ss23.difference(s1) == sortedset([2, 3]) def test_symmetric_difference(self): s = set([1, 3, 5]) ss = sortedset([2, 3, 4]) ss2 = sortedset([5, 6, 7]) - self.assertEqual(ss.symmetric_difference(s), sortedset([1, 2, 4, 5])) - self.assertFalse(ss.symmetric_difference(ss)) - self.assertEqual(ss.symmetric_difference(s), sortedset([1, 2, 4, 5])) - self.assertEqual(ss2.symmetric_difference(ss), sortedset([2, 3, 4, 5, 6, 7])) + assert ss.symmetric_difference(s) == sortedset([1, 2, 4, 5]) + assert not ss.symmetric_difference(ss) + assert ss.symmetric_difference(s) == sortedset([1, 2, 4, 5]) + assert ss2.symmetric_difference(ss) == sortedset([2, 3, 4, 5, 6, 7]) def test_pop(self): ss = sortedset([2, 1]) - self.assertEqual(ss.pop(), 2) - self.assertEqual(ss.pop(), 1) - try: + assert ss.pop() == 2 + assert ss.pop() == 1 + with pytest.raises((KeyError, IndexError)): ss.pop() - self.fail("Error not thrown") - except (KeyError, IndexError) as e: - pass + def test_remove(self): ss = sortedset([2, 1]) - self.assertEqual(len(ss), 2) - self.assertRaises(KeyError, ss.remove, 3) - self.assertEqual(len(ss), 2) + assert len(ss) == 2 + with pytest.raises(KeyError): + ss.remove(3) + assert len(ss) == 2 ss.remove(1) - self.assertEqual(len(ss), 1) + assert len(ss) == 1 ss.remove(2) - self.assertFalse(ss) - self.assertRaises(KeyError, ss.remove, 2) - self.assertFalse(ss) + assert not ss + with pytest.raises(KeyError): + ss.remove(2) + assert not ss def test_getitem(self): ss = sortedset(range(3)) for i in range(len(ss)): - self.assertEqual(ss[i], i) - 
with self.assertRaises(IndexError): + assert ss[i] == i + with pytest.raises(IndexError): ss[len(ss)] def test_delitem(self): expected = [1,2,3,4] ss = sortedset(expected) for i in range(len(ss)): - self.assertListEqual(list(ss), expected[i:]) + assertListEqual(list(ss), expected[i:]) del ss[0] - with self.assertRaises(IndexError): + with pytest.raises(IndexError): ss[0] def test_delslice(self): expected = [1, 2, 3, 4, 5] ss = sortedset(expected) del ss[1:3] - self.assertListEqual(list(ss), [1, 4, 5]) + assertListEqual(list(ss), [1, 4, 5]) del ss[-1:] - self.assertListEqual(list(ss), [1, 4]) + assertListEqual(list(ss), [1, 4]) del ss[1:] - self.assertListEqual(list(ss), [1]) + assertListEqual(list(ss), [1]) del ss[:] - self.assertFalse(ss) - with self.assertRaises(IndexError): + assert not ss + with pytest.raises(IndexError): del ss[0] def test_reversed(self): expected = range(10) - self.assertListEqual(list(reversed(sortedset(expected))), list(reversed(expected))) + assertListEqual(list(reversed(sortedset(expected))), list(reversed(expected))) def test_operators(self): ss1 = sortedset([1]) ss12 = sortedset([1, 2]) # __ne__ - self.assertFalse(ss12 != ss12) - self.assertFalse(ss12 != sortedset([1, 2])) - self.assertTrue(ss12 != sortedset()) + assert not ss12 != ss12 + assert not ss12 != sortedset([1, 2]) + assert ss12 != sortedset() # __le__ - self.assertTrue(ss1 <= ss12) - self.assertTrue(ss12 <= ss12) - self.assertFalse(ss12 <= ss1) + assert ss1 <= ss12 + assert ss12 <= ss12 + assert not ss12 <= ss1 # __lt__ - self.assertTrue(ss1 < ss12) - self.assertFalse(ss12 < ss12) - self.assertFalse(ss12 < ss1) + assert ss1 < ss12 + assert not ss12 < ss12 + assert not ss12 < ss1 # __ge__ - self.assertFalse(ss1 >= ss12) - self.assertTrue(ss12 >= ss12) - self.assertTrue(ss12 >= ss1) + assert not ss1 >= ss12 + assert ss12 >= ss12 + assert ss12 >= ss1 # __gt__ - self.assertFalse(ss1 > ss12) - self.assertFalse(ss12 > ss12) - self.assertTrue(ss12 > ss1) + assert not ss1 > ss12 + 
assert not ss12 > ss12 + assert ss12 > ss1 # __and__ - self.assertEqual(ss1 & ss12, ss1) - self.assertEqual(ss12 & ss12, ss12) - self.assertEqual(ss12 & set(), sortedset()) + assert ss1 & ss12 == ss1 + assert ss12 & ss12 == ss12 + assert ss12 & set() == sortedset() # __iand__ tmp = sortedset(ss12) tmp &= ss1 - self.assertEqual(tmp, ss1) + assert tmp == ss1 tmp = sortedset(ss1) tmp &= ss12 - self.assertEqual(tmp, ss1) + assert tmp == ss1 tmp = sortedset(ss12) tmp &= ss12 - self.assertEqual(tmp, ss12) + assert tmp == ss12 tmp = sortedset(ss12) tmp &= set() - self.assertEqual(tmp, sortedset()) + assert tmp == sortedset() # __rand__ - self.assertEqual(set([1]) & ss12, ss1) + assert set([1]) & ss12 == ss1 # __or__ - self.assertEqual(ss1 | ss12, ss12) - self.assertEqual(ss12 | ss12, ss12) - self.assertEqual(ss12 | set(), ss12) - self.assertEqual(sortedset() | ss1 | ss12, ss12) + assert ss1 | ss12 == ss12 + assert ss12 | ss12 == ss12 + assert ss12 | set() == ss12 + assert sortedset() | ss1 | ss12 == ss12 # __ior__ tmp = sortedset(ss1) tmp |= ss12 - self.assertEqual(tmp, ss12) + assert tmp == ss12 tmp = sortedset(ss12) tmp |= ss12 - self.assertEqual(tmp, ss12) + assert tmp == ss12 tmp = sortedset(ss12) tmp |= set() - self.assertEqual(tmp, ss12) + assert tmp == ss12 tmp = sortedset() tmp |= ss1 tmp |= ss12 - self.assertEqual(tmp, ss12) + assert tmp == ss12 # __ror__ - self.assertEqual(set([1]) | ss12, ss12) + assert set([1]) | ss12 == ss12 # __sub__ - self.assertEqual(ss1 - ss12, set()) - self.assertEqual(ss12 - ss12, set()) - self.assertEqual(ss12 - set(), ss12) - self.assertEqual(ss12 - ss1, sortedset([2])) + assert ss1 - ss12 == set() + assert ss12 - ss12 == set() + assert ss12 - set() == ss12 + assert ss12 - ss1 == sortedset([2]) # __isub__ tmp = sortedset(ss1) tmp -= ss12 - self.assertEqual(tmp, set()) + assert tmp == set() tmp = sortedset(ss12) tmp -= ss12 - self.assertEqual(tmp, set()) + assert tmp == set() tmp = sortedset(ss12) tmp -= set() - self.assertEqual(tmp, 
ss12) + assert tmp == ss12 tmp = sortedset(ss12) tmp -= ss1 - self.assertEqual(tmp, sortedset([2])) + assert tmp == sortedset([2]) # __rsub__ - self.assertEqual(set((1,2,3)) - ss12, set((3,))) + assert set((1,2,3)) - ss12 == set((3,)) # __xor__ - self.assertEqual(ss1 ^ ss12, set([2])) - self.assertEqual(ss12 ^ ss1, set([2])) - self.assertEqual(ss12 ^ ss12, set()) - self.assertEqual(ss12 ^ set(), ss12) + assert ss1 ^ ss12 == set([2]) + assert ss12 ^ ss1 == set([2]) + assert ss12 ^ ss12 == set() + assert ss12 ^ set() == ss12 # __ixor__ tmp = sortedset(ss1) tmp ^= ss12 - self.assertEqual(tmp, set([2])) + assert tmp == set([2]) tmp = sortedset(ss12) tmp ^= ss1 - self.assertEqual(tmp, set([2])) + assert tmp == set([2]) tmp = sortedset(ss12) tmp ^= ss12 - self.assertEqual(tmp, set()) + assert tmp == set() tmp = sortedset(ss12) tmp ^= set() - self.assertEqual(tmp, ss12) + assert tmp == ss12 # __rxor__ - self.assertEqual(set([1, 2]) ^ ss1, (set([2]))) + assert set([1, 2]) ^ ss1 == (set([2])) def test_reduce_pickle(self): ss = sortedset((4,3,2,1)) import pickle s = pickle.dumps(ss) - self.assertEqual(pickle.loads(s), ss) + assert pickle.loads(s) == ss def _test_uncomparable_types(self, items): for perm in permutations(items): ss = sortedset(perm) s = set(perm) - self.assertEqual(s, ss) - self.assertEqual(ss, ss.union(s)) + assert s == ss + assert ss == ss.union(s) for x in range(len(ss)): subset = set(s) for _ in range(x): subset.pop() - self.assertEqual(ss.difference(subset), s.difference(subset)) - self.assertEqual(ss.intersection(subset), s.intersection(subset)) + assert ss.difference(subset) == s.difference(subset) + assert ss.intersection(subset) == s.intersection(subset) for x in ss: - self.assertIn(x, ss) + assert x in ss ss.remove(x) - self.assertNotIn(x, ss) + assert x not in ss def test_uncomparable_types_with_tuples(self): # PYTHON-1087 - make set handle uncomparable types diff --git a/tests/unit/test_tablets.py b/tests/unit/test_tablets.py index 
3bbba06918..5e640fa4c9 100644 --- a/tests/unit/test_tablets.py +++ b/tests/unit/test_tablets.py @@ -4,11 +4,11 @@ class TabletsTest(unittest.TestCase): def compare_ranges(self, tablets, ranges): - self.assertEqual(len(tablets), len(ranges)) + assert len(tablets) == len(ranges) for idx, tablet in enumerate(tablets): - self.assertEqual(tablet.first_token, ranges[idx][0], "First token is not correct in tablet: {}".format(tablet)) - self.assertEqual(tablet.last_token, ranges[idx][1], "Last token is not correct in tablet: {}".format(tablet)) + assert tablet.first_token == ranges[idx][0], "First token is not correct in tablet: {}".format(tablet) + assert tablet.last_token == ranges[idx][1], "Last token is not correct in tablet: {}".format(tablet) def test_add_tablet_to_empty_tablets(self): tablets = Tablets({("test_ks", "test_tb"): []}) diff --git a/tests/unit/test_time_util.py b/tests/unit/test_time_util.py index 2605992d1c..6c2c46d180 100644 --- a/tests/unit/test_time_util.py +++ b/tests/unit/test_time_util.py @@ -20,20 +20,21 @@ import datetime import time import uuid +import pytest class TimeUtilTest(unittest.TestCase): def test_datetime_from_timestamp(self): - self.assertEqual(util.datetime_from_timestamp(0), datetime.datetime(1970, 1, 1)) + assert util.datetime_from_timestamp(0) == datetime.datetime(1970, 1, 1) # large negative; test PYTHON-110 workaround for windows - self.assertEqual(util.datetime_from_timestamp(-62135596800), datetime.datetime(1, 1, 1)) - self.assertEqual(util.datetime_from_timestamp(-62135596199), datetime.datetime(1, 1, 1, 0, 10, 1)) + assert util.datetime_from_timestamp(-62135596800) == datetime.datetime(1, 1, 1) + assert util.datetime_from_timestamp(-62135596199) == datetime.datetime(1, 1, 1, 0, 10, 1) - self.assertEqual(util.datetime_from_timestamp(253402300799), datetime.datetime(9999, 12, 31, 23, 59, 59)) + assert util.datetime_from_timestamp(253402300799) == datetime.datetime(9999, 12, 31, 23, 59, 59) - 
self.assertEqual(util.datetime_from_timestamp(0.123456), datetime.datetime(1970, 1, 1, 0, 0, 0, 123456)) + assert util.datetime_from_timestamp(0.123456) == datetime.datetime(1970, 1, 1, 0, 0, 0, 123456) - self.assertEqual(util.datetime_from_timestamp(2177403010.123456), datetime.datetime(2038, 12, 31, 10, 10, 10, 123456)) + assert util.datetime_from_timestamp(2177403010.123456) == datetime.datetime(2038, 12, 31, 10, 10, 10, 123456) def test_times_from_uuid1(self): node = uuid.getnode() @@ -41,11 +42,11 @@ def test_times_from_uuid1(self): u = uuid.uuid1(node, 0) t = util.unix_time_from_uuid1(u) - self.assertAlmostEqual(now, t, 2) + assert now == pytest.approx(t, abs=1e-2) dt = util.datetime_from_uuid1(u) t = calendar.timegm(dt.timetuple()) + dt.microsecond / 1e6 - self.assertAlmostEqual(now, t, 2) + assert now == pytest.approx(t, abs=1e-2) def test_uuid_from_time(self): t = time.time() @@ -54,42 +55,42 @@ def test_uuid_from_time(self): u = util.uuid_from_time(t, node, seq) # using AlmostEqual because time precision is different for # some platforms - self.assertAlmostEqual(util.unix_time_from_uuid1(u), t, 4) - self.assertEqual(u.node, node) - self.assertEqual(u.clock_seq, seq) + assert util.unix_time_from_uuid1(u) == pytest.approx(t, abs=1e-4) + assert u.node == node + assert u.clock_seq == seq # random node u1 = util.uuid_from_time(t, clock_seq=seq) u2 = util.uuid_from_time(t, clock_seq=seq) - self.assertAlmostEqual(util.unix_time_from_uuid1(u1), t, 4) - self.assertAlmostEqual(util.unix_time_from_uuid1(u2), t, 4) - self.assertEqual(u.clock_seq, seq) + assert util.unix_time_from_uuid1(u1) == pytest.approx(t, abs=1e-4) + assert util.unix_time_from_uuid1(u2) == pytest.approx(t, abs=1e-4) + assert u.clock_seq == seq # not impossible, but we shouldn't get the same value twice - self.assertNotEqual(u1.node, u2.node) + assert u1.node != u2.node # random seq u1 = util.uuid_from_time(t, node=node) u2 = util.uuid_from_time(t, node=node) - 
self.assertAlmostEqual(util.unix_time_from_uuid1(u1), t, 4) - self.assertAlmostEqual(util.unix_time_from_uuid1(u2), t, 4) - self.assertEqual(u.node, node) + assert util.unix_time_from_uuid1(u1) == pytest.approx(t, abs=1e-4) + assert util.unix_time_from_uuid1(u2) == pytest.approx(t, abs=1e-4) + assert u.node == node # not impossible, but we shouldn't get the same value twice - self.assertNotEqual(u1.clock_seq, u2.clock_seq) + assert u1.clock_seq != u2.clock_seq # node too large - with self.assertRaises(ValueError): + with pytest.raises(ValueError): u = util.uuid_from_time(t, node=2 ** 48) # clock_seq too large - with self.assertRaises(ValueError): + with pytest.raises(ValueError): u = util.uuid_from_time(t, clock_seq=0x4000) # construct from datetime dt = util.datetime_from_timestamp(t) u = util.uuid_from_time(dt, node, seq) - self.assertAlmostEqual(util.unix_time_from_uuid1(u), t, 4) - self.assertEqual(u.node, node) - self.assertEqual(u.clock_seq, seq) + assert util.unix_time_from_uuid1(u) == pytest.approx(t, abs=1e-4) + assert u.node == node + assert u.clock_seq == seq # 0 1 2 3 # 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 @@ -106,7 +107,7 @@ def test_min_uuid(self): u = util.min_uuid_from_time(0) # cassandra does a signed comparison of the remaining bytes for i in range(8, 16): - self.assertEqual(marshal.int8_unpack(u.bytes[i:i + 1]), -128) + assert marshal.int8_unpack(u.bytes[i:i + 1]) == -128 def test_max_uuid(self): u = util.max_uuid_from_time(0) @@ -114,6 +115,6 @@ def test_max_uuid(self): # the first non-time byte has the variant in it # This byte is always negative, but should be the smallest negative # number with high-order bits '10' - self.assertEqual(marshal.int8_unpack(u.bytes[8:9]), -65) + assert marshal.int8_unpack(u.bytes[8:9]) == -65 for i in range(9, 16): - self.assertEqual(marshal.int8_unpack(u.bytes[i:i + 1]), 127) + assert marshal.int8_unpack(u.bytes[i:i + 1]) == 127 diff --git a/tests/unit/test_timestamps.py 
b/tests/unit/test_timestamps.py index 151c004c90..8ef747d515 100644 --- a/tests/unit/test_timestamps.py +++ b/tests/unit/test_timestamps.py @@ -16,7 +16,9 @@ from unittest import mock from cassandra import timestamps +from tests.util import assertRegex from threading import Thread, Lock +import pytest class _TimestampTestMixin(object): @@ -42,10 +44,10 @@ def _call_and_check_results(self, for expected in expected_timestamps: actual = tsg() if expected is not None: - self.assertEqual(actual, expected) + assert actual == expected # assert we patched timestamps.time.time correctly - with self.assertRaises(StopIteration): + with pytest.raises(StopIteration): tsg() @@ -102,9 +104,9 @@ def setUp(self): def assertLastCallArgRegex(self, call, pattern): last_warn_args, last_warn_kwargs = call - self.assertEqual(len(last_warn_args), 1) - self.assertEqual(len(last_warn_kwargs), 0) - self.assertRegex(last_warn_args[0], pattern) + assert len(last_warn_args) == 1 + assert len(last_warn_kwargs) == 0 + assertRegex(last_warn_args[0], pattern) def test_basic_log_content(self): """ @@ -124,10 +126,10 @@ def test_basic_log_content(self): tsg._last_warn = 12 tsg._next_timestamp(20, tsg.last) - self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 0) + assert len(self.patched_timestamp_log.warning.call_args_list) == 0 tsg._next_timestamp(16, tsg.last) - self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 1) + assert len(self.patched_timestamp_log.warning.call_args_list) == 1 self.assertLastCallArgRegex( self.patched_timestamp_log.warning.call_args, r'Clock skew detected:.*\b16\b.*\b4\b.*\b20\b' @@ -147,7 +149,7 @@ def test_disable_logging(self): no_warn_tsg.last = 100 no_warn_tsg._next_timestamp(99, no_warn_tsg.last) - self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 0) + assert len(self.patched_timestamp_log.warning.call_args_list) == 0 def test_warning_threshold_respected_no_logging(self): """ @@ -164,7 +166,7 @@ def 
test_warning_threshold_respected_no_logging(self): ) tsg.last, tsg._last_warn = 100, 97 tsg._next_timestamp(98, tsg.last) - self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 0) + assert len(self.patched_timestamp_log.warning.call_args_list) == 0 def test_warning_threshold_respected_logs(self): """ @@ -182,7 +184,7 @@ def test_warning_threshold_respected_logs(self): ) tsg.last, tsg._last_warn = 100, 97 tsg._next_timestamp(98, tsg.last) - self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 1) + assert len(self.patched_timestamp_log.warning.call_args_list) == 1 def test_warning_interval_respected_no_logging(self): """ @@ -200,10 +202,10 @@ def test_warning_interval_respected_no_logging(self): ) tsg.last = 100 tsg._next_timestamp(70, tsg.last) - self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 1) + assert len(self.patched_timestamp_log.warning.call_args_list) == 1 tsg._next_timestamp(71, tsg.last) - self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 1) + assert len(self.patched_timestamp_log.warning.call_args_list) == 1 def test_warning_interval_respected_logs(self): """ @@ -222,10 +224,10 @@ def test_warning_interval_respected_logs(self): ) tsg.last = 100 tsg._next_timestamp(70, tsg.last) - self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 1) + assert len(self.patched_timestamp_log.warning.call_args_list) == 1 tsg._next_timestamp(72, tsg.last) - self.assertEqual(len(self.patched_timestamp_log.warning.call_args_list), 2) + assert len(self.patched_timestamp_log.warning.call_args_list) == 2 class TestTimestampGeneratorMultipleThreads(unittest.TestCase): @@ -266,6 +268,6 @@ def request_time(): for t in threads: t.join() - self.assertEqual(len(generated_timestamps), num_threads * timestamp_to_generate) + assert len(generated_timestamps) == num_threads * timestamp_to_generate for i, timestamp in enumerate(sorted(generated_timestamps)): - self.assertEqual(int(i + 
1e6), timestamp) + assert int(i + 1e6) == timestamp diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index aba11d4ced..3390f6dbd6 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -45,6 +45,7 @@ datetime_from_timestamp ) from tests.unit.util import check_sequence_consistency +import pytest class TypeTests(unittest.TestCase): @@ -54,83 +55,83 @@ def test_lookup_casstype_simple(self): Ensure lookup_casstype_simple returns the correct classes """ - self.assertEqual(lookup_casstype_simple('AsciiType'), cassandra.cqltypes.AsciiType) - self.assertEqual(lookup_casstype_simple('LongType'), cassandra.cqltypes.LongType) - self.assertEqual(lookup_casstype_simple('BytesType'), cassandra.cqltypes.BytesType) - self.assertEqual(lookup_casstype_simple('BooleanType'), cassandra.cqltypes.BooleanType) - self.assertEqual(lookup_casstype_simple('CounterColumnType'), cassandra.cqltypes.CounterColumnType) - self.assertEqual(lookup_casstype_simple('DecimalType'), cassandra.cqltypes.DecimalType) - self.assertEqual(lookup_casstype_simple('DoubleType'), cassandra.cqltypes.DoubleType) - self.assertEqual(lookup_casstype_simple('FloatType'), cassandra.cqltypes.FloatType) - self.assertEqual(lookup_casstype_simple('InetAddressType'), cassandra.cqltypes.InetAddressType) - self.assertEqual(lookup_casstype_simple('Int32Type'), cassandra.cqltypes.Int32Type) - self.assertEqual(lookup_casstype_simple('UTF8Type'), cassandra.cqltypes.UTF8Type) - self.assertEqual(lookup_casstype_simple('DateType'), cassandra.cqltypes.DateType) - self.assertEqual(lookup_casstype_simple('SimpleDateType'), cassandra.cqltypes.SimpleDateType) - self.assertEqual(lookup_casstype_simple('ByteType'), cassandra.cqltypes.ByteType) - self.assertEqual(lookup_casstype_simple('ShortType'), cassandra.cqltypes.ShortType) - self.assertEqual(lookup_casstype_simple('TimeUUIDType'), cassandra.cqltypes.TimeUUIDType) - self.assertEqual(lookup_casstype_simple('TimeType'), cassandra.cqltypes.TimeType) - 
self.assertEqual(lookup_casstype_simple('UUIDType'), cassandra.cqltypes.UUIDType) - self.assertEqual(lookup_casstype_simple('IntegerType'), cassandra.cqltypes.IntegerType) - self.assertEqual(lookup_casstype_simple('MapType'), cassandra.cqltypes.MapType) - self.assertEqual(lookup_casstype_simple('ListType'), cassandra.cqltypes.ListType) - self.assertEqual(lookup_casstype_simple('SetType'), cassandra.cqltypes.SetType) - self.assertEqual(lookup_casstype_simple('CompositeType'), cassandra.cqltypes.CompositeType) - self.assertEqual(lookup_casstype_simple('ColumnToCollectionType'), cassandra.cqltypes.ColumnToCollectionType) - self.assertEqual(lookup_casstype_simple('ReversedType'), cassandra.cqltypes.ReversedType) - self.assertEqual(lookup_casstype_simple('DurationType'), cassandra.cqltypes.DurationType) - self.assertEqual(lookup_casstype_simple('DateRangeType'), cassandra.cqltypes.DateRangeType) - - self.assertEqual(str(lookup_casstype_simple('unknown')), str(cassandra.cqltypes.mkUnrecognizedType('unknown'))) + assert lookup_casstype_simple('AsciiType') == cassandra.cqltypes.AsciiType + assert lookup_casstype_simple('LongType') == cassandra.cqltypes.LongType + assert lookup_casstype_simple('BytesType') == cassandra.cqltypes.BytesType + assert lookup_casstype_simple('BooleanType') == cassandra.cqltypes.BooleanType + assert lookup_casstype_simple('CounterColumnType') == cassandra.cqltypes.CounterColumnType + assert lookup_casstype_simple('DecimalType') == cassandra.cqltypes.DecimalType + assert lookup_casstype_simple('DoubleType') == cassandra.cqltypes.DoubleType + assert lookup_casstype_simple('FloatType') == cassandra.cqltypes.FloatType + assert lookup_casstype_simple('InetAddressType') == cassandra.cqltypes.InetAddressType + assert lookup_casstype_simple('Int32Type') == cassandra.cqltypes.Int32Type + assert lookup_casstype_simple('UTF8Type') == cassandra.cqltypes.UTF8Type + assert lookup_casstype_simple('DateType') == cassandra.cqltypes.DateType + assert 
lookup_casstype_simple('SimpleDateType') == cassandra.cqltypes.SimpleDateType + assert lookup_casstype_simple('ByteType') == cassandra.cqltypes.ByteType + assert lookup_casstype_simple('ShortType') == cassandra.cqltypes.ShortType + assert lookup_casstype_simple('TimeUUIDType') == cassandra.cqltypes.TimeUUIDType + assert lookup_casstype_simple('TimeType') == cassandra.cqltypes.TimeType + assert lookup_casstype_simple('UUIDType') == cassandra.cqltypes.UUIDType + assert lookup_casstype_simple('IntegerType') == cassandra.cqltypes.IntegerType + assert lookup_casstype_simple('MapType') == cassandra.cqltypes.MapType + assert lookup_casstype_simple('ListType') == cassandra.cqltypes.ListType + assert lookup_casstype_simple('SetType') == cassandra.cqltypes.SetType + assert lookup_casstype_simple('CompositeType') == cassandra.cqltypes.CompositeType + assert lookup_casstype_simple('ColumnToCollectionType') == cassandra.cqltypes.ColumnToCollectionType + assert lookup_casstype_simple('ReversedType') == cassandra.cqltypes.ReversedType + assert lookup_casstype_simple('DurationType') == cassandra.cqltypes.DurationType + assert lookup_casstype_simple('DateRangeType') == cassandra.cqltypes.DateRangeType + + assert str(lookup_casstype_simple('unknown')) == str(cassandra.cqltypes.mkUnrecognizedType('unknown')) def test_lookup_casstype(self): """ Ensure lookup_casstype returns the correct classes """ - self.assertEqual(lookup_casstype('AsciiType'), cassandra.cqltypes.AsciiType) - self.assertEqual(lookup_casstype('LongType'), cassandra.cqltypes.LongType) - self.assertEqual(lookup_casstype('BytesType'), cassandra.cqltypes.BytesType) - self.assertEqual(lookup_casstype('BooleanType'), cassandra.cqltypes.BooleanType) - self.assertEqual(lookup_casstype('CounterColumnType'), cassandra.cqltypes.CounterColumnType) - self.assertEqual(lookup_casstype('DateType'), cassandra.cqltypes.DateType) - self.assertEqual(lookup_casstype('DecimalType'), cassandra.cqltypes.DecimalType) - 
self.assertEqual(lookup_casstype('DoubleType'), cassandra.cqltypes.DoubleType) - self.assertEqual(lookup_casstype('FloatType'), cassandra.cqltypes.FloatType) - self.assertEqual(lookup_casstype('InetAddressType'), cassandra.cqltypes.InetAddressType) - self.assertEqual(lookup_casstype('Int32Type'), cassandra.cqltypes.Int32Type) - self.assertEqual(lookup_casstype('UTF8Type'), cassandra.cqltypes.UTF8Type) - self.assertEqual(lookup_casstype('DateType'), cassandra.cqltypes.DateType) - self.assertEqual(lookup_casstype('TimeType'), cassandra.cqltypes.TimeType) - self.assertEqual(lookup_casstype('ByteType'), cassandra.cqltypes.ByteType) - self.assertEqual(lookup_casstype('ShortType'), cassandra.cqltypes.ShortType) - self.assertEqual(lookup_casstype('TimeUUIDType'), cassandra.cqltypes.TimeUUIDType) - self.assertEqual(lookup_casstype('UUIDType'), cassandra.cqltypes.UUIDType) - self.assertEqual(lookup_casstype('IntegerType'), cassandra.cqltypes.IntegerType) - self.assertEqual(lookup_casstype('MapType'), cassandra.cqltypes.MapType) - self.assertEqual(lookup_casstype('ListType'), cassandra.cqltypes.ListType) - self.assertEqual(lookup_casstype('SetType'), cassandra.cqltypes.SetType) - self.assertEqual(lookup_casstype('CompositeType'), cassandra.cqltypes.CompositeType) - self.assertEqual(lookup_casstype('ColumnToCollectionType'), cassandra.cqltypes.ColumnToCollectionType) - self.assertEqual(lookup_casstype('ReversedType'), cassandra.cqltypes.ReversedType) - self.assertEqual(lookup_casstype('DurationType'), cassandra.cqltypes.DurationType) - self.assertEqual(lookup_casstype('DateRangeType'), cassandra.cqltypes.DateRangeType) - - self.assertEqual(str(lookup_casstype('unknown')), str(cassandra.cqltypes.mkUnrecognizedType('unknown'))) - - self.assertRaises(ValueError, lookup_casstype, 'AsciiType~') + assert lookup_casstype('AsciiType') == cassandra.cqltypes.AsciiType + assert lookup_casstype('LongType') == cassandra.cqltypes.LongType + assert lookup_casstype('BytesType') == 
cassandra.cqltypes.BytesType + assert lookup_casstype('BooleanType') == cassandra.cqltypes.BooleanType + assert lookup_casstype('CounterColumnType') == cassandra.cqltypes.CounterColumnType + assert lookup_casstype('DateType') == cassandra.cqltypes.DateType + assert lookup_casstype('DecimalType') == cassandra.cqltypes.DecimalType + assert lookup_casstype('DoubleType') == cassandra.cqltypes.DoubleType + assert lookup_casstype('FloatType') == cassandra.cqltypes.FloatType + assert lookup_casstype('InetAddressType') == cassandra.cqltypes.InetAddressType + assert lookup_casstype('Int32Type') == cassandra.cqltypes.Int32Type + assert lookup_casstype('UTF8Type') == cassandra.cqltypes.UTF8Type + assert lookup_casstype('DateType') == cassandra.cqltypes.DateType + assert lookup_casstype('TimeType') == cassandra.cqltypes.TimeType + assert lookup_casstype('ByteType') == cassandra.cqltypes.ByteType + assert lookup_casstype('ShortType') == cassandra.cqltypes.ShortType + assert lookup_casstype('TimeUUIDType') == cassandra.cqltypes.TimeUUIDType + assert lookup_casstype('UUIDType') == cassandra.cqltypes.UUIDType + assert lookup_casstype('IntegerType') == cassandra.cqltypes.IntegerType + assert lookup_casstype('MapType') == cassandra.cqltypes.MapType + assert lookup_casstype('ListType') == cassandra.cqltypes.ListType + assert lookup_casstype('SetType') == cassandra.cqltypes.SetType + assert lookup_casstype('CompositeType') == cassandra.cqltypes.CompositeType + assert lookup_casstype('ColumnToCollectionType') == cassandra.cqltypes.ColumnToCollectionType + assert lookup_casstype('ReversedType') == cassandra.cqltypes.ReversedType + assert lookup_casstype('DurationType') == cassandra.cqltypes.DurationType + assert lookup_casstype('DateRangeType') == cassandra.cqltypes.DateRangeType + + assert str(lookup_casstype('unknown')) == str(cassandra.cqltypes.mkUnrecognizedType('unknown')) + + with pytest.raises(ValueError): + lookup_casstype('AsciiType~') def test_casstype_parameterized(self): - 
self.assertEqual(LongType.cass_parameterized_type_with(()), 'LongType') - self.assertEqual(LongType.cass_parameterized_type_with((), full=True), 'org.apache.cassandra.db.marshal.LongType') - self.assertEqual(SetType.cass_parameterized_type_with([DecimalType], full=True), 'org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.DecimalType)') + assert LongType.cass_parameterized_type_with(()) == 'LongType' + assert LongType.cass_parameterized_type_with((), full=True) == 'org.apache.cassandra.db.marshal.LongType' + assert SetType.cass_parameterized_type_with([DecimalType], full=True) == 'org.apache.cassandra.db.marshal.SetType(org.apache.cassandra.db.marshal.DecimalType)' - self.assertEqual(LongType.cql_parameterized_type(), 'bigint') + assert LongType.cql_parameterized_type() == 'bigint' subtypes = (cassandra.cqltypes.UTF8Type, cassandra.cqltypes.UTF8Type) - self.assertEqual('map', - cassandra.cqltypes.MapType.apply_parameters(subtypes).cql_parameterized_type()) + assert 'map' == cassandra.cqltypes.MapType.apply_parameters(subtypes).cql_parameterized_type() def test_datetype_from_string(self): # Ensure all formats can be parsed, without exception @@ -143,18 +144,18 @@ def test_cql_typename(self): Smoke test cql_typename """ - self.assertEqual(cql_typename('DateType'), 'timestamp') - self.assertEqual(cql_typename('org.apache.cassandra.db.marshal.ListType(IntegerType)'), 'list') + assert cql_typename('DateType') == 'timestamp' + assert cql_typename('org.apache.cassandra.db.marshal.ListType(IntegerType)') == 'list' def test_named_tuple_colname_substitution(self): colnames = ("func(abc)", "[applied]", "func(func(abc))", "foo_bar", "foo_bar_") rows = [(1, 2, 3, 4, 5)] result = named_tuple_factory(colnames, rows)[0] - self.assertEqual(result[0], result.func_abc) - self.assertEqual(result[1], result.applied) - self.assertEqual(result[2], result.func_func_abc) - self.assertEqual(result[3], result.foo_bar) - self.assertEqual(result[4], result.foo_bar_) + 
assert result[0] == result.func_abc + assert result[1] == result.applied + assert result[2] == result.func_func_abc + assert result[3] == result.foo_bar + assert result[4] == result.foo_bar_ def test_parse_casstype_args(self): class FooType(CassandraType): @@ -178,36 +179,36 @@ class BarType(FooType): '7a6970:org.apache.cassandra.db.marshal.UTF8Type', ')'))) - self.assertEqual(FooType, ctype.__class__) + assert FooType == ctype.__class__ - self.assertEqual(UTF8Type, ctype.subtypes[0]) + assert UTF8Type == ctype.subtypes[0] # middle subtype should be a BarType instance with its own subtypes and names - self.assertIsInstance(ctype.subtypes[1], BarType) - self.assertEqual([UTF8Type], ctype.subtypes[1].subtypes) - self.assertEqual([b"address"], ctype.subtypes[1].names) + assert isinstance(ctype.subtypes[1], BarType) + assert [UTF8Type] == ctype.subtypes[1].subtypes + assert [b"address"] == ctype.subtypes[1].names - self.assertEqual(UTF8Type, ctype.subtypes[2]) - self.assertEqual([b'city', None, b'zip'], ctype.names) + assert UTF8Type == ctype.subtypes[2] + assert [b'city', None, b'zip'] == ctype.names def test_parse_casstype_vector(self): ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 3)") - self.assertTrue(issubclass(ctype, VectorType)) - self.assertEqual(3, ctype.vector_size) - self.assertEqual(FloatType, ctype.subtype) + assert issubclass(ctype, VectorType) + assert 3 == ctype.vector_size + assert FloatType == ctype.subtype def test_parse_casstype_vector_of_vectors(self): inner_type = "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)" ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(%s, 3)" % (inner_type)) - self.assertTrue(issubclass(ctype, VectorType)) - self.assertEqual(3, ctype.vector_size) + assert issubclass(ctype, VectorType) + assert 3 == ctype.vector_size sub_ctype = ctype.subtype - self.assertTrue(issubclass(sub_ctype, 
VectorType)) - self.assertEqual(4, sub_ctype.vector_size) - self.assertEqual(FloatType, sub_ctype.subtype) + assert issubclass(sub_ctype, VectorType) + assert 4 == sub_ctype.vector_size + assert FloatType == sub_ctype.subtype def test_empty_value(self): - self.assertEqual(str(EmptyValue()), 'EMPTY') + assert str(EmptyValue()) == 'EMPTY' def test_datetype(self): now_time_seconds = time.time() @@ -217,28 +218,28 @@ def test_datetype(self): now_timestamp = now_time_seconds * 1e3 # same results serialized - self.assertEqual(DateType.serialize(now_datetime, 0), DateType.serialize(now_timestamp, 0)) + assert DateType.serialize(now_datetime, 0) == DateType.serialize(now_timestamp, 0) # deserialize # epoc expected = 0 - self.assertEqual(DateType.deserialize(int64_pack(1000 * expected), 0), datetime.datetime.fromtimestamp(expected, tz=datetime.timezone.utc).replace(tzinfo=None)) + assert DateType.deserialize(int64_pack(1000 * expected), 0) == datetime.datetime.fromtimestamp(expected, tz=datetime.timezone.utc).replace(tzinfo=None) # beyond 32b expected = 2 ** 33 - self.assertEqual(DateType.deserialize(int64_pack(1000 * expected), 0), datetime.datetime(2242, 3, 16, 12, 56, 32, tzinfo=datetime.timezone.utc).replace(tzinfo=None)) + assert DateType.deserialize(int64_pack(1000 * expected), 0) == datetime.datetime(2242, 3, 16, 12, 56, 32, tzinfo=datetime.timezone.utc).replace(tzinfo=None) # less than epoc (PYTHON-119) expected = -770172256 - self.assertEqual(DateType.deserialize(int64_pack(1000 * expected), 0), datetime.datetime(1945, 8, 5, 23, 15, 44, tzinfo=datetime.timezone.utc).replace(tzinfo=None)) + assert DateType.deserialize(int64_pack(1000 * expected), 0) == datetime.datetime(1945, 8, 5, 23, 15, 44, tzinfo=datetime.timezone.utc).replace(tzinfo=None) # work around rounding difference among Python versions (PYTHON-230) expected = 1424817268.274 - self.assertEqual(DateType.deserialize(int64_pack(int(1000 * expected)), 0), datetime.datetime(2015, 2, 24, 22, 34, 28, 274000, 
tzinfo=datetime.timezone.utc).replace(tzinfo=None)) + assert DateType.deserialize(int64_pack(int(1000 * expected)), 0) == datetime.datetime(2015, 2, 24, 22, 34, 28, 274000, tzinfo=datetime.timezone.utc).replace(tzinfo=None) # Large date overflow (PYTHON-452) expected = 2177403010.123 - self.assertEqual(DateType.deserialize(int64_pack(int(1000 * expected)), 0), datetime.datetime(2038, 12, 31, 10, 10, 10, 123000, tzinfo=datetime.timezone.utc).replace(tzinfo=None)) + assert DateType.deserialize(int64_pack(int(1000 * expected)), 0) == datetime.datetime(2038, 12, 31, 10, 10, 10, 123000, tzinfo=datetime.timezone.utc).replace(tzinfo=None) def test_collection_null_support(self): """ @@ -253,16 +254,10 @@ def test_collection_null_support(self): int32_pack(4) + # size of item2 int32_pack(42) # item2 ) - self.assertEqual( - [None, 42], - int_list.deserialize(value, 3) - ) + assert [None, 42] == int_list.deserialize(value, 3) set_list = SetType.apply_parameters([Int32Type]) - self.assertEqual( - {None, 42}, - set(set_list.deserialize(value, 3)) - ) + assert {None, 42} == set(set_list.deserialize(value, 3)) value = ( int32_pack(2) + # num items @@ -275,49 +270,47 @@ def test_collection_null_support(self): ) map_list = MapType.apply_parameters([Int32Type, Int32Type]) - self.assertEqual( - [(42, None), (None, 42)], - map_list.deserialize(value, 3)._items # OrderedMapSerializedKey - ) + + assert [(42, None), (None, 42)] == map_list.deserialize(value, 3)._items # OrderedMapSerializedKey def test_write_read_string(self): with tempfile.TemporaryFile() as f: value = u'test' write_string(f, value) f.seek(0) - self.assertEqual(read_string(f), value) + assert read_string(f) == value def test_write_read_longstring(self): with tempfile.TemporaryFile() as f: value = u'test' write_longstring(f, value) f.seek(0) - self.assertEqual(read_longstring(f), value) + assert read_longstring(f) == value def test_write_read_stringmap(self): with tempfile.TemporaryFile() as f: value = {'key': 'value'} 
write_stringmap(f, value) f.seek(0) - self.assertEqual(read_stringmap(f), value) + assert read_stringmap(f) == value def test_write_read_inet(self): with tempfile.TemporaryFile() as f: value = ('192.168.1.1', 9042) write_inet(f, value) f.seek(0) - self.assertEqual(read_inet(f), value) + assert read_inet(f) == value with tempfile.TemporaryFile() as f: value = ('2001:db8:0:f101::1', 9042) write_inet(f, value) f.seek(0) - self.assertEqual(read_inet(f), value) + assert read_inet(f) == value def test_cql_quote(self): - self.assertEqual(cql_quote(u'test'), "'test'") - self.assertEqual(cql_quote('test'), "'test'") - self.assertEqual(cql_quote(0), '0') + assert cql_quote(u'test') == "'test'" + assert cql_quote('test') == "'test'" + assert cql_quote(0) == '0' class VectorTests(unittest.TestCase): @@ -328,31 +321,31 @@ def _normalize_set(self, val): def _round_trip_compare_fn(self, first, second): if isinstance(first, float): - self.assertAlmostEqual(first, second, places=5) + assert first == pytest.approx(second, rel=1e-5) elif isinstance(first, list): - self.assertEqual(len(first), len(second)) + assert len(first) == len(second) for (felem, selem) in zip(first, second): self._round_trip_compare_fn(felem, selem) elif isinstance(first, set) or isinstance(first, frozenset): - self.assertEqual(len(first), len(second)) + assert len(first) == len(second) first_norm = self._normalize_set(first) second_norm = self._normalize_set(second) - self.assertEqual(first_norm, second_norm) + assert first_norm == second_norm elif isinstance(first, dict): for ((fk,fv), (sk,sv)) in zip(first.items(), second.items()): self._round_trip_compare_fn(fk, sk) self._round_trip_compare_fn(fv, sv) else: - self.assertEqual(first,second) + assert first == second def _round_trip_test(self, data, ctype_str): ctype = parse_casstype_args(ctype_str) data_bytes = ctype.serialize(data, 0) serialized_size = ctype.subtype.serial_size() if serialized_size: - self.assertEqual(serialized_size * len(data), 
len(data_bytes)) + assert serialized_size * len(data) == len(data_bytes) result = ctype.deserialize(data_bytes, 0) - self.assertEqual(len(data), len(result)) + assert len(data) == len(result) for idx in range(0,len(data)): self._round_trip_compare_fn(data[idx], result[idx]) @@ -460,60 +453,60 @@ def test_round_trip_vector_of_vectors(self): def test_cql_parameterized_type(self): # Base vector functionality ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") - self.assertEqual(ctype.cql_parameterized_type(), "org.apache.cassandra.db.marshal.VectorType") + assert ctype.cql_parameterized_type() == "org.apache.cassandra.db.marshal.VectorType" # Test vector-of-vectors inner_type = "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)" ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(%s, 3)" % (inner_type)) inner_parsed_type = "org.apache.cassandra.db.marshal.VectorType" - self.assertEqual(ctype.cql_parameterized_type(), "org.apache.cassandra.db.marshal.VectorType<%s, 3>" % (inner_parsed_type)) + assert ctype.cql_parameterized_type() == "org.apache.cassandra.db.marshal.VectorType<%s, 3>" % (inner_parsed_type) def test_serialization_fixed_size_too_small(self): ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 5)") - with self.assertRaisesRegex(ValueError, "Expected sequence of size 5 for vector of type float and dimension 5, observed sequence of length 4"): + with pytest.raises(ValueError, match="Expected sequence of size 5 for vector of type float and dimension 5, observed sequence of length 4"): ctype.serialize([1.2, 3.4, 5.6, 7.8], 0) def test_serialization_fixed_size_too_big(self): ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") - with self.assertRaisesRegex(ValueError, "Expected sequence of size 4 for vector 
of type float and dimension 4, observed sequence of length 5"): + with pytest.raises(ValueError, match="Expected sequence of size 4 for vector of type float and dimension 4, observed sequence of length 5"): ctype.serialize([1.2, 3.4, 5.6, 7.8, 9.10], 0) def test_serialization_variable_size_too_small(self): ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 5)") - with self.assertRaisesRegex(ValueError, "Expected sequence of size 5 for vector of type varint and dimension 5, observed sequence of length 4"): + with pytest.raises(ValueError, match="Expected sequence of size 5 for vector of type varint and dimension 5, observed sequence of length 4"): ctype.serialize([1, 2, 3, 4], 0) def test_serialization_variable_size_too_big(self): ctype = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)") - with self.assertRaisesRegex(ValueError, "Expected sequence of size 4 for vector of type varint and dimension 4, observed sequence of length 5"): + with pytest.raises(ValueError, match="Expected sequence of size 4 for vector of type varint and dimension 4, observed sequence of length 5"): ctype.serialize([1, 2, 3, 4, 5], 0) def test_deserialization_fixed_size_too_small(self): ctype_four = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") ctype_four_bytes = ctype_four.serialize([1.2, 3.4, 5.6, 7.8], 0) ctype_five = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 5)") - with self.assertRaisesRegex(ValueError, "Expected vector of type float and dimension 5 to have serialized size 20; observed serialized size of 16 instead"): + with pytest.raises(ValueError, match="Expected vector of type float and dimension 5 to have serialized size 20; observed serialized size of 16 instead"): ctype_five.deserialize(ctype_four_bytes, 0) def 
test_deserialization_fixed_size_too_big(self): ctype_five = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 5)") ctype_five_bytes = ctype_five.serialize([1.2, 3.4, 5.6, 7.8, 9.10], 0) ctype_four = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)") - with self.assertRaisesRegex(ValueError, "Expected vector of type float and dimension 4 to have serialized size 16; observed serialized size of 20 instead"): + with pytest.raises(ValueError, match="Expected vector of type float and dimension 4 to have serialized size 16; observed serialized size of 20 instead"): ctype_four.deserialize(ctype_five_bytes, 0) def test_deserialization_variable_size_too_small(self): ctype_four = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)") ctype_four_bytes = ctype_four.serialize([1, 2, 3, 4], 0) ctype_five = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 5)") - with self.assertRaisesRegex(ValueError, "Error reading additional data during vector deserialization after successfully adding 4 elements"): + with pytest.raises(ValueError, match="Error reading additional data during vector deserialization after successfully adding 4 elements"): ctype_five.deserialize(ctype_four_bytes, 0) def test_deserialization_variable_size_too_big(self): ctype_five = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 5)") ctype_five_bytes = ctype_five.serialize([1, 2, 3, 4, 5], 0) ctype_four = parse_casstype_args("org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.IntegerType, 4)") - with self.assertRaisesRegex(ValueError, "Additional bytes remaining after vector deserialization completed"): + with pytest.raises(ValueError, match="Additional bytes remaining after vector deserialization 
completed"): ctype_four.deserialize(ctype_five_bytes, 0) @@ -553,7 +546,7 @@ def test_month_rounding_creation_failure(self): dr = DateRange(OPEN_BOUND, DateRangeBound(feb_stamp, DateRangePrecision.MONTH)) dt = datetime_from_timestamp(dr.upper_bound.milliseconds / 1000) - self.assertEqual(dt.day, 28) + assert dt.day == 28 # Leap year feb_stamp_leap_year = ms_timestamp_from_datetime( @@ -562,32 +555,29 @@ def test_month_rounding_creation_failure(self): dr = DateRange(OPEN_BOUND, DateRangeBound(feb_stamp_leap_year, DateRangePrecision.MONTH)) dt = datetime_from_timestamp(dr.upper_bound.milliseconds / 1000) - self.assertEqual(dt.day, 29) + assert dt.day == 29 def test_decode_precision(self): - self.assertEqual(DateRangeType._decode_precision(6), 'MILLISECOND') + assert DateRangeType._decode_precision(6) == 'MILLISECOND' def test_decode_precision_error(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): DateRangeType._decode_precision(-1) def test_encode_precision(self): - self.assertEqual(DateRangeType._encode_precision('SECOND'), 5) + assert DateRangeType._encode_precision('SECOND') == 5 def test_encode_precision_error(self): - with self.assertRaises(ValueError): + with pytest.raises(ValueError): DateRangeType._encode_precision('INVALID') def test_deserialize_single_value(self): serialized = (int8_pack(0) + int64_pack(self.timestamp) + int8_pack(3)) - self.assertEqual( - DateRangeType.deserialize(serialized, 5), - util.DateRange(value=util.DateRangeBound( - value=datetime.datetime(2017, 2, 1, 15, 42, 12, 404000), - precision='HOUR') - ) + assert DateRangeType.deserialize(serialized, 5) == util.DateRange(value=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1, 15, 42, 12, 404000), + precision='HOUR') ) def test_deserialize_closed_range(self): @@ -596,17 +586,14 @@ def test_deserialize_closed_range(self): int8_pack(2) + int64_pack(self.timestamp) + int8_pack(6)) - self.assertEqual( - DateRangeType.deserialize(serialized, 5), - 
util.DateRange( - lower_bound=util.DateRangeBound( - value=datetime.datetime(2017, 2, 1, 0, 0), - precision='DAY' - ), - upper_bound=util.DateRangeBound( - value=datetime.datetime(2017, 2, 1, 15, 42, 12, 404000), - precision='MILLISECOND' - ) + assert DateRangeType.deserialize(serialized, 5) == util.DateRange( + lower_bound=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1, 0, 0), + precision='DAY' + ), + upper_bound=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1, 15, 42, 12, 404000), + precision='MILLISECOND' ) ) @@ -615,15 +602,12 @@ def test_deserialize_open_high(self): int64_pack(self.timestamp) + int8_pack(3)) deserialized = DateRangeType.deserialize(serialized, 5) - self.assertEqual( - deserialized, - util.DateRange( - lower_bound=util.DateRangeBound( - value=datetime.datetime(2017, 2, 1, 15, 0), - precision='HOUR' - ), - upper_bound=util.OPEN_BOUND - ) + assert deserialized == util.DateRange( + lower_bound=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1, 15, 0), + precision='HOUR' + ), + upper_bound=util.OPEN_BOUND ) def test_deserialize_open_low(self): @@ -631,35 +615,26 @@ def test_deserialize_open_low(self): int64_pack(self.timestamp) + int8_pack(4)) deserialized = DateRangeType.deserialize(serialized, 5) - self.assertEqual( - deserialized, - util.DateRange( - lower_bound=util.OPEN_BOUND, - upper_bound=util.DateRangeBound( - value=datetime.datetime(2017, 2, 1, 15, 42, 20, 1000), - precision='MINUTE' - ) + assert deserialized == util.DateRange( + lower_bound=util.OPEN_BOUND, + upper_bound=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1, 15, 42, 20, 1000), + precision='MINUTE' ) ) def test_deserialize_single_open(self): - self.assertEqual( - util.DateRange(value=util.OPEN_BOUND), - DateRangeType.deserialize(int8_pack(5), 5) - ) + assert util.DateRange(value=util.OPEN_BOUND) == DateRangeType.deserialize(int8_pack(5), 5) def test_serialize_single_value(self): serialized = (int8_pack(0) + int64_pack(self.timestamp) + 
int8_pack(5)) deserialized = DateRangeType.deserialize(serialized, 5) - self.assertEqual( - deserialized, - util.DateRange( - value=util.DateRangeBound( - value=datetime.datetime(2017, 2, 1, 15, 42, 12), - precision='SECOND' - ) + assert deserialized == util.DateRange( + value=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1, 15, 42, 12), + precision='SECOND' ) ) @@ -670,17 +645,14 @@ def test_serialize_closed_range(self): int64_pack(self.timestamp) + int8_pack(0)) deserialized = DateRangeType.deserialize(serialized, 5) - self.assertEqual( - deserialized, - util.DateRange( - lower_bound=util.DateRangeBound( - value=datetime.datetime(2017, 2, 1, 15, 42, 12), - precision='SECOND' - ), - upper_bound=util.DateRangeBound( - value=datetime.datetime(2017, 12, 31), - precision='YEAR' - ) + assert deserialized == util.DateRange( + lower_bound=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1, 15, 42, 12), + precision='SECOND' + ), + upper_bound=util.DateRangeBound( + value=datetime.datetime(2017, 12, 31), + precision='YEAR' ) ) @@ -689,15 +661,12 @@ def test_serialize_open_high(self): int64_pack(self.timestamp) + int8_pack(2)) deserialized = DateRangeType.deserialize(serialized, 5) - self.assertEqual( - deserialized, - util.DateRange( - lower_bound=util.DateRangeBound( - value=datetime.datetime(2017, 2, 1), - precision='DAY' - ), - upper_bound=util.OPEN_BOUND - ) + assert deserialized == util.DateRange( + lower_bound=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1), + precision='DAY' + ), + upper_bound=util.OPEN_BOUND ) def test_serialize_open_low(self): @@ -705,57 +674,50 @@ def test_serialize_open_low(self): int64_pack(self.timestamp) + int8_pack(3)) deserialized = DateRangeType.deserialize(serialized, 5) - self.assertEqual( - deserialized, - util.DateRange( - lower_bound=util.DateRangeBound( - value=datetime.datetime(2017, 2, 1, 15), - precision='HOUR' - ), - upper_bound=util.OPEN_BOUND - ) + assert deserialized == util.DateRange( + 
lower_bound=util.DateRangeBound( + value=datetime.datetime(2017, 2, 1, 15), + precision='HOUR' + ), + upper_bound=util.OPEN_BOUND ) def test_deserialize_both_open(self): serialized = (int8_pack(4)) deserialized = DateRangeType.deserialize(serialized, 5) - self.assertEqual( - deserialized, - util.DateRange( - lower_bound=util.OPEN_BOUND, - upper_bound=util.OPEN_BOUND - ) + assert deserialized == util.DateRange( + lower_bound=util.OPEN_BOUND, + upper_bound=util.OPEN_BOUND ) def test_serialize_single_open(self): serialized = DateRangeType.serialize(util.DateRange( value=util.OPEN_BOUND, ), 5) - self.assertEqual(int8_pack(5), serialized) + assert int8_pack(5) == serialized def test_serialize_both_open(self): serialized = DateRangeType.serialize(util.DateRange( lower_bound=util.OPEN_BOUND, upper_bound=util.OPEN_BOUND ), 5) - self.assertEqual(int8_pack(4), serialized) + assert int8_pack(4) == serialized def test_failure_to_serialize_no_value_object(self): - self.assertRaises(ValueError, DateRangeType.serialize, object(), 5) + with pytest.raises(ValueError): + DateRangeType.serialize(object(), 5) def test_failure_to_serialize_no_bounds_object(self): class no_bounds_object(object): value = lower_bound = None - self.assertRaises(ValueError, DateRangeType.serialize, no_bounds_object, 5) + with pytest.raises(ValueError): + DateRangeType.serialize(no_bounds_object, 5) def test_serialized_value_round_trip(self): vals = [b'\x01\x00\x00\x01%\xe9a\xf9\xd1\x06\x00\x00\x01v\xbb>o\xff\x00', b'\x01\x00\x00\x00\xdcm\x03-\xd1\x06\x00\x00\x01v\xbb>o\xff\x00'] for serialized in vals: - self.assertEqual( - serialized, - DateRangeType.serialize(DateRangeType.deserialize(serialized, 0), 0) - ) + assert serialized == DateRangeType.serialize(DateRangeType.deserialize(serialized, 0), 0) def test_serialize_zero_datetime(self): """ @@ -826,8 +788,8 @@ def test_deserialize_date_range_milliseconds(self): upper_value = self.starting_upper_value + i dr = DateRange(DateRangeBound(lower_value, 
DateRangePrecision.MILLISECOND), DateRangeBound(upper_value, DateRangePrecision.MILLISECOND)) - self.assertEqual(lower_value, dr.lower_bound.milliseconds) - self.assertEqual(upper_value, dr.upper_bound.milliseconds) + assert lower_value == dr.lower_bound.milliseconds + assert upper_value == dr.upper_bound.milliseconds def test_deserialize_date_range_seconds(self): """ @@ -852,9 +814,9 @@ def truncate_last_figures(number, n=3): dr = DateRange(DateRangeBound(lower_value, DateRangePrecision.SECOND), DateRangeBound(upper_value, DateRangePrecision.SECOND)) - self.assertEqual(truncate_last_figures(lower_value), dr.lower_bound.milliseconds) + assert truncate_last_figures(lower_value) == dr.lower_bound.milliseconds upper_value = truncate_last_figures(upper_value) + 999 - self.assertEqual(upper_value, dr.upper_bound.milliseconds) + assert upper_value == dr.upper_bound.milliseconds def test_deserialize_date_range_minutes(self): """ @@ -1021,9 +983,9 @@ def truncate_date(number): DateRangeBound(upper_value, precision)) # We verify that rounded value corresponds with what we would expect - self.assertEqual(truncate_date(lower_value), dr.lower_bound.milliseconds) + assert truncate_date(lower_value) == dr.lower_bound.milliseconds upper_value = round_up_truncated_upper_value(truncate_date(upper_value)) - self.assertEqual(upper_value, dr.upper_bound.milliseconds) + assert upper_value == dr.upper_bound.milliseconds class TestOrdering(unittest.TestCase): @@ -1045,9 +1007,9 @@ def test_host_order(self): hosts_equal = [Host(addr, SimpleConvictionPolicy) for addr in ("127.0.0.1", "127.0.0.1")] hosts_equal_conviction = [Host("127.0.0.1", SimpleConvictionPolicy), Host("127.0.0.1", ConvictionPolicy)] - check_sequence_consistency(self, hosts) - check_sequence_consistency(self, hosts_equal, equal=True) - check_sequence_consistency(self, hosts_equal_conviction, equal=True) + check_sequence_consistency(hosts) + check_sequence_consistency(hosts_equal, equal=True) + 
check_sequence_consistency(hosts_equal_conviction, equal=True) def test_date_order(self): """ @@ -1061,8 +1023,8 @@ def test_date_order(self): """ dates_from_string = [Date("2017-01-01"), Date("2017-01-05"), Date("2017-01-09"), Date("2017-01-13")] dates_from_string_equal = [Date("2017-01-01"), Date("2017-01-01")] - check_sequence_consistency(self, dates_from_string) - check_sequence_consistency(self, dates_from_string_equal, equal=True) + check_sequence_consistency(dates_from_string) + check_sequence_consistency(dates_from_string_equal, equal=True) date_format = "%Y-%m-%d" @@ -1072,15 +1034,15 @@ def test_date_order(self): for dtstr in ("2017-01-02", "2017-01-06", "2017-01-10", "2017-01-14") ] dates_from_value_equal = [Date(1), Date(1)] - check_sequence_consistency(self, dates_from_value) - check_sequence_consistency(self, dates_from_value_equal, equal=True) + check_sequence_consistency(dates_from_value) + check_sequence_consistency(dates_from_value_equal, equal=True) dates_from_datetime = [Date(datetime.datetime.strptime(dtstr, date_format)) for dtstr in ("2017-01-03", "2017-01-07", "2017-01-11", "2017-01-15")] dates_from_datetime_equal = [Date(datetime.datetime.strptime("2017-01-01", date_format)), Date(datetime.datetime.strptime("2017-01-01", date_format))] - check_sequence_consistency(self, dates_from_datetime) - check_sequence_consistency(self, dates_from_datetime_equal, equal=True) + check_sequence_consistency(dates_from_datetime) + check_sequence_consistency(dates_from_datetime_equal, equal=True) dates_from_date = [ Date(datetime.datetime.strptime(dtstr, date_format).date()) for dtstr in @@ -1089,10 +1051,10 @@ def test_date_order(self): dates_from_date_equal = [datetime.datetime.strptime(dtstr, date_format) for dtstr in ("2017-01-09", "2017-01-9")] - check_sequence_consistency(self, dates_from_date) - check_sequence_consistency(self, dates_from_date_equal, equal=True) + check_sequence_consistency(dates_from_date) + 
check_sequence_consistency(dates_from_date_equal, equal=True) - check_sequence_consistency(self, self._shuffle_lists(dates_from_string, dates_from_value, + check_sequence_consistency(self._shuffle_lists(dates_from_string, dates_from_value, dates_from_datetime, dates_from_date)) def test_timer_order(self): @@ -1107,23 +1069,23 @@ def test_timer_order(self): """ time_from_int = [Time(1000), Time(4000), Time(7000), Time(10000)] time_from_int_equal = [Time(1), Time(1)] - check_sequence_consistency(self, time_from_int) - check_sequence_consistency(self, time_from_int_equal, equal=True) + check_sequence_consistency(time_from_int) + check_sequence_consistency(time_from_int_equal, equal=True) time_from_datetime = [Time(datetime.time(hour=0, minute=0, second=0, microsecond=us)) for us in (2, 5, 8, 11)] time_from_datetime_equal = [Time(datetime.time(hour=0, minute=0, second=0, microsecond=us)) for us in (1, 1)] - check_sequence_consistency(self, time_from_datetime) - check_sequence_consistency(self, time_from_datetime_equal, equal=True) + check_sequence_consistency(time_from_datetime) + check_sequence_consistency(time_from_datetime_equal, equal=True) time_from_string = [Time("00:00:00.000003000"), Time("00:00:00.000006000"), Time("00:00:00.000009000"), Time("00:00:00.000012000")] time_from_string_equal = [Time("00:00:00.000004000"), Time("00:00:00.000004000")] - check_sequence_consistency(self, time_from_string) - check_sequence_consistency(self, time_from_string_equal, equal=True) + check_sequence_consistency(time_from_string) + check_sequence_consistency(time_from_string_equal, equal=True) - check_sequence_consistency(self, self._shuffle_lists(time_from_int, time_from_datetime, time_from_string)) + check_sequence_consistency(self._shuffle_lists(time_from_int, time_from_datetime, time_from_string)) def test_token_order(self): """ @@ -1137,5 +1099,5 @@ def test_token_order(self): """ tokens = [Token(1), Token(2), Token(3), Token(4)] tokens_equal = [Token(1), Token(1)] - 
check_sequence_consistency(self, tokens) - check_sequence_consistency(self, tokens_equal, equal=True) + check_sequence_consistency(tokens) + check_sequence_consistency(tokens_equal, equal=True) diff --git a/tests/unit/test_util_types.py b/tests/unit/test_util_types.py index a2551ba20b..4a115affbc 100644 --- a/tests/unit/test_util_types.py +++ b/tests/unit/test_util_types.py @@ -16,6 +16,7 @@ import datetime from cassandra.util import Date, Time, Duration, Version, maybe_add_timeout_to_query +import pytest class DateTests(unittest.TestCase): @@ -23,57 +24,57 @@ class DateTests(unittest.TestCase): def test_from_datetime(self): expected_date = datetime.date(1492, 10, 12) d = Date(expected_date) - self.assertEqual(str(d), str(expected_date)) + assert str(d) == str(expected_date) def test_from_string(self): expected_date = datetime.date(1492, 10, 12) d = Date(expected_date) sd = Date('1492-10-12') - self.assertEqual(sd, d) + assert sd == d sd = Date('+1492-10-12') - self.assertEqual(sd, d) + assert sd == d def test_from_date(self): expected_date = datetime.date(1492, 10, 12) d = Date(expected_date) - self.assertEqual(d.date(), expected_date) + assert d.date() == expected_date def test_from_days(self): sd = Date(0) - self.assertEqual(sd, Date(datetime.date(1970, 1, 1))) + assert sd == Date(datetime.date(1970, 1, 1)) sd = Date(-1) - self.assertEqual(sd, Date(datetime.date(1969, 12, 31))) + assert sd == Date(datetime.date(1969, 12, 31)) sd = Date(1) - self.assertEqual(sd, Date(datetime.date(1970, 1, 2))) + assert sd == Date(datetime.date(1970, 1, 2)) def test_limits(self): min_builtin = Date(datetime.date(1, 1, 1)) max_builtin = Date(datetime.date(9999, 12, 31)) - self.assertEqual(Date(min_builtin.days_from_epoch), min_builtin) - self.assertEqual(Date(max_builtin.days_from_epoch), max_builtin) + assert Date(min_builtin.days_from_epoch) == min_builtin + assert Date(max_builtin.days_from_epoch) == max_builtin # just proving we can construct with on offset outside buildin 
range - self.assertEqual(Date(min_builtin.days_from_epoch - 1).days_from_epoch, - min_builtin.days_from_epoch - 1) - self.assertEqual(Date(max_builtin.days_from_epoch + 1).days_from_epoch, - max_builtin.days_from_epoch + 1) + assert Date(min_builtin.days_from_epoch - 1).days_from_epoch == min_builtin.days_from_epoch - 1 + assert Date(max_builtin.days_from_epoch + 1).days_from_epoch == max_builtin.days_from_epoch + 1 def test_invalid_init(self): - self.assertRaises(ValueError, Date, '-1999-10-10') - self.assertRaises(TypeError, Date, 1.234) + with pytest.raises(ValueError): + Date('-1999-10-10') + with pytest.raises(TypeError): + Date(1.234) def test_str(self): date_str = '2015-03-16' - self.assertEqual(str(Date(date_str)), date_str) + assert str(Date(date_str)) == date_str def test_out_of_range(self): - self.assertEqual(str(Date(2932897)), '2932897') - self.assertEqual(repr(Date(1)), 'Date(1)') + assert str(Date(2932897)) == '2932897' + assert repr(Date(1)) == 'Date(1)' def test_equals(self): - self.assertEqual(Date(1234), 1234) - self.assertEqual(Date(1), datetime.date(1970, 1, 2)) - self.assertFalse(Date(2932897) == datetime.date(9999, 12, 31)) # date can't represent year > 9999 - self.assertEqual(Date(2932897), 2932897) + assert Date(1234) == 1234 + assert Date(1) == datetime.date(1970, 1, 2) + assert not Date(2932897) == datetime.date(9999, 12, 31) # date can't represent year > 9999 + assert Date(2932897) == 2932897 class TimeTests(unittest.TestCase): @@ -86,31 +87,31 @@ def test_units_from_string(self): one_hour = 60 * one_minute tt = Time('00:00:00.000000001') - self.assertEqual(tt.nanosecond_time, 1) + assert tt.nanosecond_time == 1 tt = Time('00:00:00.000001') - self.assertEqual(tt.nanosecond_time, one_micro) + assert tt.nanosecond_time == one_micro tt = Time('00:00:00.001') - self.assertEqual(tt.nanosecond_time, one_milli) + assert tt.nanosecond_time == one_milli tt = Time('00:00:01') - self.assertEqual(tt.nanosecond_time, one_second) + assert 
tt.nanosecond_time == one_second tt = Time('00:01:00') - self.assertEqual(tt.nanosecond_time, one_minute) + assert tt.nanosecond_time == one_minute tt = Time('01:00:00') - self.assertEqual(tt.nanosecond_time, one_hour) + assert tt.nanosecond_time == one_hour tt = Time('01:00:00.') - self.assertEqual(tt.nanosecond_time, one_hour) + assert tt.nanosecond_time == one_hour tt = Time('23:59:59.123456') - self.assertEqual(tt.nanosecond_time, 23 * one_hour + 59 * one_minute + 59 * one_second + 123 * one_milli + 456 * one_micro) + assert tt.nanosecond_time == 23 * one_hour + 59 * one_minute + 59 * one_second + 123 * one_milli + 456 * one_micro tt = Time('23:59:59.1234567') - self.assertEqual(tt.nanosecond_time, 23 * one_hour + 59 * one_minute + 59 * one_second + 123 * one_milli + 456 * one_micro + 700) + assert tt.nanosecond_time == 23 * one_hour + 59 * one_minute + 59 * one_second + 123 * one_milli + 456 * one_micro + 700 tt = Time('23:59:59.12345678') - self.assertEqual(tt.nanosecond_time, 23 * one_hour + 59 * one_minute + 59 * one_second + 123 * one_milli + 456 * one_micro + 780) + assert tt.nanosecond_time == 23 * one_hour + 59 * one_minute + 59 * one_second + 123 * one_milli + 456 * one_micro + 780 tt = Time('23:59:59.123456789') - self.assertEqual(tt.nanosecond_time, 23 * one_hour + 59 * one_minute + 59 * one_second + 123 * one_milli + 456 * one_micro + 789) + assert tt.nanosecond_time == 23 * one_hour + 59 * one_minute + 59 * one_second + 123 * one_milli + 456 * one_micro + 789 def test_micro_precision(self): Time('23:59:59.1') @@ -121,32 +122,36 @@ def test_micro_precision(self): def test_from_int(self): tt = Time(12345678) - self.assertEqual(tt.nanosecond_time, 12345678) + assert tt.nanosecond_time == 12345678 def test_from_time(self): expected_time = datetime.time(12, 1, 2, 3) tt = Time(expected_time) - self.assertEqual(tt, expected_time) + assert tt == expected_time def test_as_time(self): expected_time = datetime.time(12, 1, 2, 3) tt = Time(expected_time) - 
self.assertEqual(tt.time(), expected_time) + assert tt.time() == expected_time def test_equals(self): # util.Time self equality - self.assertEqual(Time(1234), Time(1234)) + assert Time(1234) == Time(1234) def test_str_repr(self): time_str = '12:13:14.123456789' - self.assertEqual(str(Time(time_str)), time_str) - self.assertEqual(repr(Time(1)), 'Time(1)') + assert str(Time(time_str)) == time_str + assert repr(Time(1)) == 'Time(1)' def test_invalid_init(self): - self.assertRaises(ValueError, Time, '1999-10-10 11:11:11.1234') - self.assertRaises(TypeError, Time, 1.234) - self.assertRaises(ValueError, Time, 123456789000000) - self.assertRaises(TypeError, Time, datetime.datetime(2004, 12, 23, 11, 11, 1)) + with pytest.raises(ValueError): + Time('1999-10-10 11:11:11.1234') + with pytest.raises(TypeError): + Time(1.234) + with pytest.raises(ValueError): + Time(123456789000000) + with pytest.raises(TypeError): + Time(datetime.datetime(2004, 12, 23, 11, 11, 1)) class DurationTests(unittest.TestCase): @@ -154,53 +159,53 @@ class DurationTests(unittest.TestCase): def test_valid_format(self): valid = Duration(1, 1, 1) - self.assertEqual(valid.months, 1) - self.assertEqual(valid.days, 1) - self.assertEqual(valid.nanoseconds, 1) + assert valid.months == 1 + assert valid.days == 1 + assert valid.nanoseconds == 1 valid = Duration(nanoseconds=100000) - self.assertEqual(valid.months, 0) - self.assertEqual(valid.days, 0) - self.assertEqual(valid.nanoseconds, 100000) + assert valid.months == 0 + assert valid.days == 0 + assert valid.nanoseconds == 100000 valid = Duration() - self.assertEqual(valid.months, 0) - self.assertEqual(valid.days, 0) - self.assertEqual(valid.nanoseconds, 0) + assert valid.months == 0 + assert valid.days == 0 + assert valid.nanoseconds == 0 valid = Duration(-10, -21, -1000) - self.assertEqual(valid.months, -10) - self.assertEqual(valid.days, -21) - self.assertEqual(valid.nanoseconds, -1000) + assert valid.months == -10 + assert valid.days == -21 + assert 
valid.nanoseconds == -1000 def test_equality(self): first = Duration(1, 1, 1) second = Duration(-1, 1, 1) - self.assertNotEqual(first, second) + assert first != second first = Duration(1, 1, 1) second = Duration(1, 1, 1) - self.assertEqual(first, second) + assert first == second first = Duration() second = Duration(0, 0, 0) - self.assertEqual(first, second) + assert first == second first = Duration(1000, 10000, 2345345) second = Duration(1000, 10000, 2345345) - self.assertEqual(first, second) + assert first == second first = Duration(12, 0 , 100) second = Duration(nanoseconds=100, months=12) - self.assertEqual(first, second) + assert first == second def test_str(self): - self.assertEqual(str(Duration(1, 1, 1)), "1mo1d1ns") - self.assertEqual(str(Duration(1, 1, -1)), "-1mo1d1ns") - self.assertEqual(str(Duration(1, 1, 1000000000000000)), "1mo1d1000000000000000ns") - self.assertEqual(str(Duration(52, 23, 564564)), "52mo23d564564ns") + assert str(Duration(1, 1, 1)) == "1mo1d1ns" + assert str(Duration(1, 1, -1)) == "-1mo1d1ns" + assert str(Duration(1, 1, 1000000000000000)) == "1mo1d1000000000000000ns" + assert str(Duration(52, 23, 564564)) == "52mo23d564564ns" class VersionTests(unittest.TestCase): @@ -223,79 +228,73 @@ def test_version_parsing(self): for str_version, expected_result in versions: v = Version(str_version) - self.assertEqual(str_version, str(v)) - self.assertEqual(v.major, expected_result[0]) - self.assertEqual(v.minor, expected_result[1]) - self.assertEqual(v.patch, expected_result[2]) - self.assertEqual(v.build, expected_result[3]) - self.assertEqual(v.prerelease, expected_result[4]) + assert str_version == str(v) + assert v.major == expected_result[0] + assert v.minor == expected_result[1] + assert v.patch == expected_result[2] + assert v.build == expected_result[3] + assert v.prerelease == expected_result[4] # not supported version formats - with self.assertRaises(ValueError): + with pytest.raises(ValueError): Version('test.1.0') def 
test_version_compare(self): # just tests a bunch of versions # major wins - self.assertTrue(Version('3.3.0') > Version('2.5.0')) - self.assertTrue(Version('3.3.0') > Version('2.5.0.66')) - self.assertTrue(Version('3.3.0') > Version('2.5.21')) + assert Version('3.3.0') > Version('2.5.0') + assert Version('3.3.0') > Version('2.5.0.66') + assert Version('3.3.0') > Version('2.5.21') # minor wins - self.assertTrue(Version('2.3.0') > Version('2.2.0')) - self.assertTrue(Version('2.3.0') > Version('2.2.7')) - self.assertTrue(Version('2.3.0') > Version('2.2.7.9')) + assert Version('2.3.0') > Version('2.2.0') + assert Version('2.3.0') > Version('2.2.7') + assert Version('2.3.0') > Version('2.2.7.9') # patch wins - self.assertTrue(Version('2.3.1') > Version('2.3.0')) - self.assertTrue(Version('2.3.1') > Version('2.3.0.4post0')) - self.assertTrue(Version('2.3.1') > Version('2.3.0.44')) + assert Version('2.3.1') > Version('2.3.0') + assert Version('2.3.1') > Version('2.3.0.4post0') + assert Version('2.3.1') > Version('2.3.0.44') # various - self.assertTrue(Version('2.3.0.1') > Version('2.3.0.0')) - self.assertTrue(Version('2.3.0.680') > Version('2.3.0.670')) - self.assertTrue(Version('2.3.0.681') > Version('2.3.0.680')) - self.assertTrue(Version('2.3.0.1build0') > Version('2.3.0.1')) # 4th part fallback to str cmp - self.assertTrue(Version('2.3.0.build0') > Version('2.3.0.1')) # 4th part fallback to str cmp - self.assertTrue(Version('2.3.0') < Version('2.3.0.build')) - - self.assertTrue(Version('4-a') <= Version('4.0.0')) - self.assertTrue(Version('4-a') <= Version('4.0-alpha1')) - self.assertTrue(Version('4-a') <= Version('4.0-beta1')) - self.assertTrue(Version('4.0.0') >= Version('4.0.0')) - self.assertTrue(Version('4.0.0.421') >= Version('4.0.0')) - self.assertTrue(Version('4.0.1') >= Version('4.0.0')) - self.assertTrue(Version('2.3.0') == Version('2.3.0')) - self.assertTrue(Version('2.3.32') == Version('2.3.32')) - self.assertTrue(Version('2.3.32') == Version('2.3.32.0')) - 
self.assertTrue(Version('2.3.0.build') == Version('2.3.0.build')) - - self.assertTrue(Version('4') == Version('4.0.0')) - self.assertTrue(Version('4.0') == Version('4.0.0.0')) - self.assertTrue(Version('4.0') > Version('3.9.3')) - - self.assertTrue(Version('4.0') > Version('4.0-SNAPSHOT')) - self.assertTrue(Version('4.0-SNAPSHOT') == Version('4.0-SNAPSHOT')) - self.assertTrue(Version('4.0.0-SNAPSHOT') == Version('4.0-SNAPSHOT')) - self.assertTrue(Version('4.0.0-SNAPSHOT') == Version('4.0.0-SNAPSHOT')) - self.assertTrue(Version('4.0.0.build5-SNAPSHOT') == Version('4.0.0.build5-SNAPSHOT')) - self.assertTrue(Version('4.1-SNAPSHOT') > Version('4.0-SNAPSHOT')) - self.assertTrue(Version('4.0.1-SNAPSHOT') > Version('4.0.0-SNAPSHOT')) - self.assertTrue(Version('4.0.0.build6-SNAPSHOT') > Version('4.0.0.build5-SNAPSHOT')) - self.assertTrue(Version('4.0-SNAPSHOT2') > Version('4.0-SNAPSHOT1')) - self.assertTrue(Version('4.0-SNAPSHOT2') > Version('4.0.0-SNAPSHOT1')) - - self.assertTrue(Version('4.0.0-alpha1-SNAPSHOT') > Version('4.0.0-SNAPSHOT')) + assert Version('2.3.0.1') > Version('2.3.0.0') + assert Version('2.3.0.680') > Version('2.3.0.670') + assert Version('2.3.0.681') > Version('2.3.0.680') + assert Version('2.3.0.1build0') > Version('2.3.0.1') # 4th part fallback to str cmp + assert Version('2.3.0.build0') > Version('2.3.0.1') # 4th part fallback to str cmp + assert Version('2.3.0') < Version('2.3.0.build') + + assert Version('4-a') <= Version('4.0.0') + assert Version('4-a') <= Version('4.0-alpha1') + assert Version('4-a') <= Version('4.0-beta1') + assert Version('4.0.0') >= Version('4.0.0') + assert Version('4.0.0.421') >= Version('4.0.0') + assert Version('4.0.1') >= Version('4.0.0') + assert Version('2.3.0') == Version('2.3.0') + assert Version('2.3.32') == Version('2.3.32') + assert Version('2.3.32') == Version('2.3.32.0') + assert Version('2.3.0.build') == Version('2.3.0.build') + + assert Version('4') == Version('4.0.0') + assert Version('4.0') == 
Version('4.0.0.0') + assert Version('4.0') > Version('3.9.3') + + assert Version('4.0') > Version('4.0-SNAPSHOT') + assert Version('4.0-SNAPSHOT') == Version('4.0-SNAPSHOT') + assert Version('4.0.0-SNAPSHOT') == Version('4.0-SNAPSHOT') + assert Version('4.0.0-SNAPSHOT') == Version('4.0.0-SNAPSHOT') + assert Version('4.0.0.build5-SNAPSHOT') == Version('4.0.0.build5-SNAPSHOT') + assert Version('4.1-SNAPSHOT') > Version('4.0-SNAPSHOT') + assert Version('4.0.1-SNAPSHOT') > Version('4.0.0-SNAPSHOT') + assert Version('4.0.0.build6-SNAPSHOT') > Version('4.0.0.build5-SNAPSHOT') + assert Version('4.0-SNAPSHOT2') > Version('4.0-SNAPSHOT1') + assert Version('4.0-SNAPSHOT2') > Version('4.0.0-SNAPSHOT1') + + assert Version('4.0.0-alpha1-SNAPSHOT') > Version('4.0.0-SNAPSHOT') class FunctionTests(unittest.TestCase): def test_maybe_add_timeout_to_query(self): - self.assertEqual( - "SELECT * FROM HOSTS", - maybe_add_timeout_to_query("SELECT * FROM HOSTS", None) - ) - self.assertEqual( - "SELECT * FROM HOSTS USING TIMEOUT 1000ms", - maybe_add_timeout_to_query("SELECT * FROM HOSTS", datetime.timedelta(seconds=1)) - ) + assert "SELECT * FROM HOSTS" == maybe_add_timeout_to_query("SELECT * FROM HOSTS", None) + assert "SELECT * FROM HOSTS USING TIMEOUT 1000ms" == maybe_add_timeout_to_query("SELECT * FROM HOSTS", datetime.timedelta(seconds=1)) diff --git a/tests/unit/util.py b/tests/unit/util.py index e57fa6c3ee..603eb4d9b5 100644 --- a/tests/unit/util.py +++ b/tests/unit/util.py @@ -11,20 +11,20 @@ # limitations under the License. 
-def check_sequence_consistency(unit_test, ordered_sequence, equal=False): +def check_sequence_consistency(ordered_sequence, equal=False): for i, el in enumerate(ordered_sequence): for previous in ordered_sequence[:i]: - _check_order_consistency(unit_test, previous, el, equal) + _check_order_consistency(previous, el, equal) for posterior in ordered_sequence[i + 1:]: - _check_order_consistency(unit_test, el, posterior, equal) + _check_order_consistency(el, posterior, equal) -def _check_order_consistency(unit_test, smaller, bigger, equal=False): - unit_test.assertLessEqual(smaller, bigger) - unit_test.assertGreaterEqual(bigger, smaller) +def _check_order_consistency(smaller, bigger, equal=False): + assert smaller <= bigger + assert bigger >= smaller if equal: - unit_test.assertEqual(smaller, bigger) + assert smaller == bigger else: - unit_test.assertNotEqual(smaller, bigger) - unit_test.assertLess(smaller, bigger) - unit_test.assertGreater(bigger, smaller) + assert smaller != bigger + assert smaller < bigger + assert bigger > smaller diff --git a/tests/util.py b/tests/util.py index 5c7ac2416f..2439e20fd5 100644 --- a/tests/util.py +++ b/tests/util.py @@ -11,9 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-
+
 import time
 from functools import wraps
+import re
+import unittest
+import difflib
+import pytest
 
 
 def wait_until(condition, delay, max_attempts):
@@ -72,3 +76,38 @@ def wrapper(*args, **kwargs):
             func(*args, **kwargs)
         return wrapper
     return decorator
+
+def assertRegex(text: str, pattern: str):
+    assert re.search(pattern, text)
+
+unittest_test_case = unittest.TestCase()
+
+def assertSequenceEqual(a, b, seq_type=None):
+    unittest_test_case.assertSequenceEqual(a, b, seq_type=seq_type)
+
+def assertDictEqual(a, b):
+    unittest_test_case.assertDictEqual(a, b)
+
+def assertListEqual(a, b):
+    assertSequenceEqual(a, b, seq_type=list)
+
+def assertSetEqual(a, b):
+    unittest_test_case.assertSetEqual(a, b)
+
+def assertCountEqual(a, b):
+    unittest_test_case.assertCountEqual(a, b)
+
+def assertEqual(a, b):
+    assert a == b
+
+def assert_startswith_diff(text, prefix):
+    if not text.startswith(prefix):
+        prefix_lines = prefix.split('\n')
+        diff_string = '\n'.join(difflib.unified_diff(prefix_lines,
+                                                     text.split('\n')[:len(prefix_lines)],
+                                                     'EXPECTED', 'RECEIVED',
+                                                     lineterm=''))
+        pytest.fail(diff_string)
+
+def assertIsInstance(a, b):
+    assert isinstance(a, b)