Skip to content

Commit cb65f54

Browse files
committed
tests: visitors and speed. added sympy requirement to test
1 parent f144bae commit cb65f54

File tree

3 files changed

+65
-53
lines changed

3 files changed

+65
-53
lines changed

requirements/test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ codecov
44
coverage
55
pytest-cov
66
pytest-env
7+
sympy

tests/test_speed.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,21 @@
1414
##############################################################################
1515
"""Tests for refinableobj module."""
1616

17-
from __future__ import print_function
18-
1917
import random
2018

2119
import numpy
2220

2321
import diffpy.srfit.equation.literals as literals
2422
import diffpy.srfit.equation.visitors as visitors
2523

26-
from .utils import _makeArgs
27-
2824
x = numpy.arange(0, 20, 0.05)
2925

3026

31-
def makeLazyEquation():
27+
def makeLazyEquation(make_args):
3228
"""Make a lazy equation and see how fast it is."""
3329

3430
# Make some variables
35-
v1, v2, v3, v4, v5, v6, v7 = _makeArgs(7)
31+
v1, v2, v3, v4, v5, v6, v7 = make_args(7)
3632

3733
# Make some operations
3834
mult = literals.MultiplicationOperator()
@@ -463,14 +459,17 @@ def profileTest():
463459
return
464460

465461

466-
if __name__ == "__main__":
467-
for i in range(1, 13):
468-
speedTest2(i)
469-
"""
470-
for i in range(1, 9):
471-
weightedTest(i)
472-
"""
473-
"""From diffpy.srfit.equation.builder import EquationFactory import random
474-
import cProfile cProfile.run('profileTest()', 'prof') import pstats p =
475-
pstats.Stats('prof') p.strip_dirs() p.sort_stats('time') p.print_stats(10)
476-
profileTest()"""
462+
# if __name__ == "__main__":
463+
# for i in range(1, 13):
464+
# speedTest2(i)
465+
# """
466+
# for i in range(1, 9):
467+
# weightedTest(i)
468+
# """
469+
# """From diffpy.srfit.equation.builder import
470+
# EquationFactory import random
471+
# import cProfile cProfile.run('profileTest()', 'prof')
472+
# import pstats p =
473+
# pstats.Stats('prof') p.strip_dirs() p.sort_stats('time')
474+
# p.print_stats(10)
475+
# profileTest()"""

tests/test_visitors.py

Lines changed: 48 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,22 @@
1616

1717
import unittest
1818

19+
import pytest
20+
1921
import diffpy.srfit.equation.literals as literals
2022
import diffpy.srfit.equation.visitors as visitors
2123

2224

23-
class TestValidator(unittest.TestCase):
25+
class TestValidator:
26+
@pytest.fixture(autouse=True)
27+
def setup(self, make_args):
28+
self.make_args = make_args
2429

2530
def testSimpleFunction(self):
2631
"""Test a simple function."""
2732

2833
# Make some variables
29-
v1, v2, v3, v4 = _makeArgs(4)
34+
v1, v2, v3, v4 = self.make_args(4)
3035

3136
# Make some operations
3237
mult = literals.MultiplicationOperator()
@@ -49,50 +54,53 @@ def testSimpleFunction(self):
4954
# Now validate
5055
validator = visitors.Validator()
5156
mult.identify(validator)
52-
self.assertEqual(4, len(validator.errors))
57+
assert 4 == len(validator.errors)
5358

5459
# Fix the equation
5560
minus.addLiteral(v3)
5661
validator.reset()
5762
mult.identify(validator)
58-
self.assertEqual(3, len(validator.errors))
63+
assert 3 == len(validator.errors)
5964

6065
# Fix the name of plus
6166
plus.name = "add"
6267
validator.reset()
6368
mult.identify(validator)
64-
self.assertEqual(2, len(validator.errors))
69+
assert 2 == len(validator.errors)
6570

6671
# Fix the symbol of plus
6772
plus.symbol = "+"
6873
validator.reset()
6974
mult.identify(validator)
70-
self.assertEqual(1, len(validator.errors))
75+
assert 1 == len(validator.errors)
7176

7277
# Fix the operation of plus
7378
import numpy
7479

7580
plus.operation = numpy.add
7681
validator.reset()
7782
mult.identify(validator)
78-
self.assertEqual(0, len(validator.errors))
83+
assert 0 == len(validator.errors)
7984

8085
# Add another literal to minus
8186
minus.addLiteral(v1)
8287
validator.reset()
8388
mult.identify(validator)
84-
self.assertEqual(1, len(validator.errors))
89+
assert 1 == len(validator.errors)
8590

8691
return
8792

8893

89-
class TestArgFinder(unittest.TestCase):
94+
class TestArgFinder:
95+
@pytest.fixture(autouse=True)
96+
def setup(self, make_args):
97+
self.make_args = make_args
9098

9199
def testSimpleFunction(self):
92100
"""Test a simple function."""
93101

94102
# Make some variables
95-
v1, v2, v3, v4 = _makeArgs(4)
103+
v1, v2, v3, v4 = self.make_args(4)
96104

97105
# Make some operations
98106
mult = literals.MultiplicationOperator()
@@ -116,33 +124,36 @@ def testSimpleFunction(self):
116124

117125
# now get the args
118126
args = visitors.getArgs(mult)
119-
self.assertEqual(4, len(args))
120-
self.assertTrue(v1 in args)
121-
self.assertTrue(v2 in args)
122-
self.assertTrue(v3 in args)
123-
self.assertTrue(v4 in args)
127+
assert 4 == len(args)
128+
assert v1 in args
129+
assert v2 in args
130+
assert v3 in args
131+
assert v4 in args
124132

125133
return
126134

127135
def testArg(self):
128136
"""Test just an Argument equation."""
129137
# Make some variables
130-
v1 = _makeArgs(1)[0]
138+
v1 = self.make_args(1)[0]
131139

132140
args = visitors.getArgs(v1)
133141

134-
self.assertEqual(1, len(args))
135-
self.assertTrue(args[0] is v1)
142+
assert 1 == len(args)
143+
assert args[0] == v1
136144
return
137145

138146

139-
class TestSwapper(unittest.TestCase):
147+
class TestSwapper:
148+
@pytest.fixture(autouse=True)
149+
def setup(self, make_args):
150+
self.make_args = make_args
140151

141152
def testSimpleFunction(self):
142153
"""Test a simple function."""
143154

144155
# Make some variables
145-
v1, v2, v3, v4, v5 = _makeArgs(5)
156+
v1, v2, v3, v4, v5 = self.make_args(5)
146157

147158
# Make some operations
148159
mult = literals.MultiplicationOperator()
@@ -166,43 +177,44 @@ def testSimpleFunction(self):
166177
v5.setValue(5)
167178

168179
# Evaluate
169-
self.assertEqual(8, mult.value)
180+
assert 8 == mult.value
170181

171182
# Now swap an argument
172183
visitors.swap(mult, v2, v5)
173184

174185
# Check that the operator value is invalidated
175-
self.assertTrue(mult._value is None)
176-
self.assertFalse(v2.hasObserver(minus._flush))
177-
self.assertTrue(v5.hasObserver(minus._flush))
186+
assert mult._value is None
187+
assert not v2.hasObserver(minus._flush)
188+
assert v5.hasObserver(minus._flush)
178189

179190
# now get the args
180191
args = visitors.getArgs(mult)
181-
self.assertEqual(4, len(args))
182-
self.assertTrue(v1 in args)
183-
self.assertTrue(v2 not in args)
184-
self.assertTrue(v3 in args)
185-
self.assertTrue(v4 in args)
186-
self.assertTrue(v5 in args)
192+
assert 4 == len(args)
193+
assert v1 in args
194+
assert v2 not in args
195+
assert v3 in args
196+
assert v4 in args
197+
assert v5 in args
187198

188199
# Re-evaluate (1+3)*(4-5) = -4
189-
self.assertEqual(-4, mult.value)
200+
assert -4 == mult.value
190201

191202
# Swap out the "-" operator
192203
plus2 = literals.AdditionOperator()
193204
visitors.swap(mult, minus, plus2)
194-
self.assertTrue(mult._value is None)
195-
self.assertFalse(minus.hasObserver(mult._flush))
196-
self.assertTrue(plus2.hasObserver(mult._flush))
205+
assert mult._value is None
206+
assert not minus.hasObserver(mult._flush)
207+
assert plus2.hasObserver(mult._flush)
197208

198209
# plus2 has no arguments yet. Verify this.
199-
self.assertRaises(TypeError, mult.getValue)
210+
with pytest.raises(TypeError):
211+
mult.getValue()
200212
# Add the arguments to plus2.
201213
plus2.addLiteral(v4)
202214
plus2.addLiteral(v5)
203215

204216
# Re-evaluate (1+3)*(4+5) = 36
205-
self.assertEqual(36, mult.value)
217+
assert 36 == mult.value
206218

207219
return
208220

0 commit comments

Comments
 (0)