Unit testing

This notebook introduces the basic concepts of testing programs. It focuses on simplicity, presents some basic techniques, and guides you through important aspects to consider when writing tests.

Please note that automatic tests are usually not placed in notebooks but in separate modules. This makes them easier to organise and run.

We will use pytest, an easy-to-use third-party tool for writing and running tests in Python. In addition, we will use some functionality from the testing module of NumPy. There are many alternatives to these tools, most notably the unittest package that comes with the Python standard library.

Magnetic Quantum Numbers

We begin with a simple example.

We want to implement a function which computes the allowed magnetic quantum numbers \(m\) for any given angular momentum \(j\). Remember that \(m = -j, -j+1, \ldots , j-1, j\). This function is implemented below.

Note:

The code samples are not always written in a clean way but are sometimes overly convoluted in order to garnish them with hidden bugs. If you are looking for advice on how to write good Python, look elsewhere!

import numpy as np
def magnetic_quantum_numbers(j):
    """
    Returns a list of all magnetic quantum numbers for a given angular momentum j.
    The results are sorted in ascending order.
    """
    if j < 0:
        raise ValueError("j must be greater than or equal to 0")
    negative = np.arange(-j, 0)
    positive = -negative[::-1]
    return list(np.concatenate((negative, [0], positive)))

In order to test this function, we can compare its output to what we would expect. For example:

  • \(j = 0 \rightarrow m = 0\)

  • \(j = 1 \rightarrow m = 0, \pm1\)

According to its docstring, magnetic_quantum_numbers promises to return a list of numbers sorted in ascending order. This implies that for \(j=0,1\), we expect to get [0] and [-1, 0, 1], respectively.

We can encode tests for these by using assertions. The assert keyword in Python takes a condition and, if that condition is false, raises an AssertionError.

assert magnetic_quantum_numbers(0) == [0]
assert magnetic_quantum_numbers(1) == [-1, 0, 1]

No output was produced, which means that both assertions passed. Our tests were successful!

If this is not explicit enough for you, feel free to add a print statement at the end of a cell to confirm that the tests passed, as in the next cell.

Those tests are great and all, but two tests are hardly enough. In particular, we have only used integer values for \(j\). Let’s try a half-integer now:

assert magnetic_quantum_numbers(0.5) == [-0.5, 0.5]
print("Success!")
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[4], line 1
----> 1 assert magnetic_quantum_numbers(0.5) == [-0.5, 0.5]
      2 print("Success!")

AssertionError: 

This is not so great anymore. What you can see above (if you executed the cell) is one of the aforementioned AssertionErrors.

Unfortunately, it does not provide any useful information other than letting us know that the test failed.

We will later introduce methods for getting better outputs for failed tests. For now, just look at the result of our function manually:

magnetic_quantum_numbers(0.5)
[-0.5, 0.0, 0.5]

There is an extra 0 in the list. This was fine for integer \(j\), but it should not be there for half-integers. We will fix magnetic_quantum_numbers later, but for now we leave it as it is.
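A bare assert only reports that something failed. Python's assert statement also accepts an optional second expression, which becomes the message of the AssertionError and is only evaluated when the condition is false. Here is a minimal sketch. It reuses the output we just saw as a stand-in for the function call; in a real test, `result` would come from calling magnetic_quantum_numbers(0.5).

```python
# Stand-in values copied from the output above.
result = [-0.5, 0.0, 0.5]
expected = [-0.5, 0.5]

try:
    # The f-string after the comma is only evaluated if the comparison fails.
    assert result == expected, f"got {result}, expected {expected}"
except AssertionError as err:
    message = str(err)

print(message)  # got [-0.5, 0.0, 0.5], expected [-0.5, 0.5]
```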

Error Modes

Until now, we have only looked at the ‘happy path’ of our function, meaning that all inputs are valid and the computation goes through to the end, albeit sometimes incorrectly.

But it is equally important to test the error paths. If a function does not detect failures or invalid inputs, it might silently produce garbage results that can be very hard to notice.


Aside: Exceptions

In Python, errors are almost always reported in the form of exceptions. Here, you do not need to understand exactly how they work. Put simply, in order to signal (‘raise’) an error, we write

raise RuntimeError("message")

We can pass any message we want to describe the error. There are multiple predefined error types in Python. RuntimeError is a general error that we can use in most situations. When a function argument is invalid, one usually raises a ValueError.


magnetic_quantum_numbers has two conditions for its input:

  • \(j \geq 0\)

  • \(j\) is an integer or half-integer

We thus expect our function to raise a ValueError if either condition is violated. To test for that, we use the external package pytest, which provides many utilities for writing tests. Here, we need the context manager pytest.raises. Again, there is no need to understand what a context manager is. For our purposes, simply read

with pytest.raises(ValueError):

as ‘check that the indented code below it raises a ValueError’.

import pytest

with pytest.raises(ValueError):
    magnetic_quantum_numbers(-1)
with pytest.raises(ValueError):
    magnetic_quantum_numbers(0.2)
---------------------------------------------------------------------------
Failed                                    Traceback (most recent call last)
Cell In[6], line 5
      3 with pytest.raises(ValueError):
      4     magnetic_quantum_numbers(-1)
----> 5 with pytest.raises(ValueError):
      6     magnetic_quantum_numbers(0.2)

    [... skipping hidden 1 frame]

File /opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/_pytest/outcomes.py:178, in fail(reason, pytrace)
    165 """Explicitly fail an executing test with the given message.
    166 
    167 :param reason:
   (...)
    175     The exception that is raised.
    176 """
    177 __tracebackhide__ = True
--> 178 raise Failed(msg=reason, pytrace=pytrace)

Failed: DID NOT RAISE <class 'ValueError'>

Another failure! But look at the error message: only the second check failed, because magnetic_quantum_numbers(0.2) did not raise a ValueError. The first one succeeded.

Looking at the implementation of magnetic_quantum_numbers above shows that there is indeed a check for \(j < 0\) but none for the integer condition.

Below, you can find a complete implementation which passes all of our tests. Feel free to try it out and write more tests.

def magnetic_quantum_numbers(j):
    if j < 0:
        raise ValueError("j must be greater than or equal to 0")
    if (2 * j) != int(2 * j):
        raise ValueError("j must be an integer or half-integer")
    return list(np.arange(-j, j + 1, 1))


with pytest.raises(ValueError):
    magnetic_quantum_numbers(-1)
with pytest.raises(ValueError):
    magnetic_quantum_numbers(0.2)
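The corrected function can also be checked against the defining properties of \(m\) for many values of \(j\) at once:

  • there are \(2j+1\) values;
  • they run from \(-j\) to \(j\) in steps of 1;
  • they are symmetric about 0.

The sketch below loops over integer and half-integer \(j\) up to 5, which is an arbitrary cut-off. It repeats the function so that the block runs on its own.

```python
import numpy as np


def magnetic_quantum_numbers(j):
    # Corrected version from above, repeated so this block is self-contained.
    if j < 0:
        raise ValueError("j must be greater than or equal to 0")
    if (2 * j) != int(2 * j):
        raise ValueError("j must be an integer or half-integer")
    return list(np.arange(-j, j + 1, 1))


for two_j in range(0, 11):
    j = two_j / 2
    m = np.array(magnetic_quantum_numbers(j))
    assert len(m) == two_j + 1                  # 2j + 1 states
    assert m[0] == -j and m[-1] == j            # runs from -j to j
    assert np.all(np.diff(m) == 1)              # ascending in steps of 1
    np.testing.assert_array_equal(m, -m[::-1])  # symmetric about 0
print("All property checks passed")
```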

Zeeman Splitting

We can now use our function to compute physical observables. As an example, we are going to look at the Zeeman effect. Let us start by reminding ourselves of the formula.

The coupling of the spin and orbital angular momentum of an electron to an external magnetic field \(B\) shifts the energy level of that electron by

\[ \Delta E = \mu_B g_j m_j B , \]

where \(\mu_B\) is the Bohr magneton, \(g_j\) the Landé factor, and \(m_j\) the magnetic quantum number that we computed above. For simplicity, we only look at energy shifts relative to the magnetic field and omit \(B\) in the following.

First, we define a function which computes all energy shifts for a given set of angular momenta:

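from scipy.constants import physical_constants


def zeeman_shifts(j, l, s):
    # Landé g-factor, assuming g_l = 1 and g_s = 2
    lande = 1 + (j * (j + 1) - l * (l + 1) + s * (s + 1)) / (2 * j * (j + 1))
    magneton = physical_constants["Bohr magneton in eV/T"][0]
    magn = magnetic_quantum_numbers(j)
    # energy shift per unit magnetic field for each m_j, in eV/T
    return [magneton * m * lande for m in magn]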

We are now faced with a difficult problem. We have a piece of Python code that implements a mathematical equation, but how can we test whether the two match? Before, we had a simple rule for what the output should be, which made it easy to write tests comparing the actual to the expected output. Now, however, it seems as though we need the output of our function in order to know what result it should produce in the first place.

Luckily, though, people have already performed many calculations based on the Zeeman effect. So we can pick one or more of those results and compare our implementation with them.

Let us choose the Lyman-α line, meaning the transition of an electron from an \(n = 2\) orbital to \(n=1\) in hydrogen. The impact of the Zeeman effect on this transition depends, of course, on the angular momenta involved. For a transition from \(|n, l, j, m_j\rangle = |2, 1, \frac{1}{2}, +\frac{1}{2}\rangle\) to \(|1, 0, \frac{1}{2}, +\frac{1}{2}\rangle\), the Zeeman effect shifts the transition energy by \(\Delta E = - \frac{2}{3} \mu_B\). Since we are dealing with electrons, \(s = \frac{1}{2}\) is fixed. So let us use that as a test:

from scipy.constants import physical_constants

# The `[0]` in physical_constants['name'][0] extracts the actual value.
bohr_magneton = physical_constants["Bohr magneton in eV/T"][0]
before = zeeman_shifts(1 / 2, 1, 1 / 2)
after = zeeman_shifts(1 / 2, 0, 1 / 2)
assert before[1] - after[1] == -2 / 3 * bohr_magneton

The test passes! Note that we are using before[1] and after[1], which extract the \(m_j = +\frac{1}{2}\) elements according to

magnetic_quantum_numbers(1 / 2)
[-0.5, 0.5]

As a nice bonus, we are loading the Bohr magneton from SciPy with a specific unit. This lets us check if our function uses the units we expect.
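For reference, the expected value of \(-\frac{2}{3}\mu_B\) can be reproduced by hand. Use the Landé factor from zeeman_shifts, \(g_j = 1 + \frac{j(j+1) - l(l+1) + s(s+1)}{2j(j+1)}\), with \(j = s = \frac{1}{2}\) for both states:

\[ g_j(l=1) = 1 + \frac{\frac{3}{4} - 2 + \frac{3}{4}}{\frac{3}{2}} = \frac{2}{3}, \qquad g_j(l=0) = 1 + \frac{\frac{3}{4} - 0 + \frac{3}{4}}{\frac{3}{2}} = 2 \]

\[ \Delta E = \left(\frac{2}{3} \cdot \frac{1}{2} - 2 \cdot \frac{1}{2}\right) \mu_B = \left(\frac{1}{3} - 1\right) \mu_B = -\frac{2}{3}\, \mu_B \]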

There are many more transitions we can use. For instance, the Zeeman shift of the transition from \(|2, 1, \frac{3}{2}, +\frac{1}{2}\rangle\) to \(|1, 0, \frac{1}{2}, -\frac{1}{2}\rangle\) is \(\frac{5}{3} \mu_B\):

before = zeeman_shifts(3 / 2, 1, 1 / 2)
after = zeeman_shifts(1 / 2, 0, 1 / 2)
assert before[2] - after[0] == 5 / 3 * bohr_magneton
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[11], line 3
      1 before = zeeman_shifts(3 / 2, 1, 1 / 2)
      2 after = zeeman_shifts(1 / 2, 0, 1 / 2)
----> 3 assert before[2] - after[0] == 5 / 3 * bohr_magneton

AssertionError: 

What happened this time? Again, the failure message does not give us a lot of information. So let’s output the results ourselves:

print(before[2] - after[0])
print(5 / 3 * bohr_magneton)
9.647303009999999e-05
9.64730301e-05

Pretty close but not an exact match! The problem here is that we are performing calculations with floating point numbers. Those have limited precision, and we cannot expect two equivalent calculations to produce exactly the same results unless they do exactly the same operations in exactly the same order.
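Here is a classic illustration, independent of the Zeeman code. Addition is associative mathematically, but in floating point every intermediate result is rounded, so the grouping of operations matters:

```python
# Mathematically both sums equal 0.6, but the intermediate results are
# rounded differently in double precision.
a = (0.1 + 0.2) + 0.3
b = 0.1 + (0.2 + 0.3)

print(a == b)  # False
print(a, b)    # 0.6000000000000001 0.6
```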

This means that we should not use the equality operator (==) as above but rather test for approximate equality. A useful tool for this is the testing module of NumPy. It provides assert_almost_equal, which by default checks that two values agree to 7 decimal places:

np.testing.assert_almost_equal(before[2] - after[0], 5 / 3 * bohr_magneton)

Now there is no output, meaning that the test passes.
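One caveat: assert_almost_equal compares the absolute difference against \(1.5 \times 10^{-\text{decimal}}\), with decimal=7 by default. For quantities as small as these (around \(10^{-4}\) eV/T), that tolerance is loose in relative terms.

The sketch below demonstrates this with a deliberately wrong value. It also shows that np.testing.assert_allclose, which uses a relative tolerance rtol, catches the error. The choice rtol=1e-9 is an illustrative assumption.

```python
import numpy as np

expected = 9.64730301e-05  # 5/3 * mu_B in eV/T, as printed above
wrong = expected * 1.001   # deliberately off by 0.1 %

# Absolute comparison: |wrong - expected| is about 9.6e-8 < 1.5e-7,
# so this passes even though the value is wrong.
np.testing.assert_almost_equal(wrong, expected)

# Relative comparison: requires |wrong - expected| <= rtol * |expected|.
try:
    np.testing.assert_allclose(wrong, expected, rtol=1e-9)
    caught = False
except AssertionError:
    caught = True
print(caught)  # True

# The tiny rounding difference seen above is well within this relative tolerance.
np.testing.assert_allclose(9.647303009999999e-05, expected, rtol=1e-9)
```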

We could keep going and add more tests, and you should do this in practice. But these kinds of tests are often of limited use because we generally do not have known outputs that we can easily compare to.

A different approach (potentially supplementing the above) is searching for simple cases where we know the correct result. One such case would be a particle without spin. Here, the interaction with the magnetic field depends only on the orbital angular momentum and should thus give an energy shift of \(m_j \mu_B\). And indeed, our implementation satisfies this:

j = 1  # l = j here
magn = magnetic_quantum_numbers(j)
delta_E = zeeman_shifts(j, j, 0)
for i in range(len(magn)):
    np.testing.assert_almost_equal(
        delta_E[i], magn[i] * physical_constants["Bohr magneton in eV/T"][0]
    )

Yet another testing ansatz is searching for properties of the expected results that we can test without knowing the actual numerical values. An example is the length of the output. We know that zeeman_shifts has to produce a list with the same number of elements as magnetic_quantum_numbers. This gives us an opportunity for testing with a broad range of different inputs.

Write tests comparing the lengths below. You can use the plain assert here because lengths are integers and can be compared exactly. You can simply list a few possible inputs j, l, s or, for extra points, write one or more loops to iterate through several different values. Note that not all combinations are physically allowed: \(j = l+s, l+s-1, \ldots , |l-s|\).

Solution:

test_momenta = ((1, 1, 0), (1.5, 0, 1.5), (1.5, 1, 0.5), (0.5, 2, 1.5))
for j, l, s in test_momenta:
    assert len(magnetic_quantum_numbers(j)) == len(zeeman_shifts(j, l, s))

Another property is that \(\Delta E\) is proportional to \(m_j\). This is a simple relationship that we can exploit! We can do so in two ways:

  • \(m_j\) always appears in pairs with opposite sign except for \(m_j=0\). The Zeeman shifts thus appear in pairs as well. We can write tests that match the corresponding elements. Mind the order of results produced by magnetic_quantum_numbers.

  • Alternatively, we divide out \(m_j\) which should give us the same value for every list element returned by zeeman_shifts. There is one exception that needs to be handled differently: \(m_j = 0\).

As an exercise, implement one or both of these tests. While the second approach is more thorough, the first one can still be a good exercise. It might help to convert the list returned by zeeman_shifts to a numpy array first.

Solution:

# point 1
for j, l, s in test_momenta:
    delta_E = zeeman_shifts(j, l, s)
    # compare each shift with its mirror image (m_j <-> -m_j)
    for i in range(len(delta_E)):
        np.testing.assert_almost_equal(delta_E[i], -delta_E[-i - 1])

# point 2
for j, l, s in test_momenta:
    delta_E = zeeman_shifts(j, l, s)
    m_j = magnetic_quantum_numbers(j)
    reference = delta_E[0] / m_j[0]
    for i in range(1, len(delta_E)):
        if m_j[i] == 0:
            np.testing.assert_almost_equal(delta_E[i], 0)
        else:
            np.testing.assert_almost_equal(delta_E[i] / m_j[i], reference)

A last test before we are done with the Zeeman effect. Earlier, we talked about how important it is to also test the error modes of our code. So let’s do this! zeeman_shifts has the same conditions on j as magnetic_quantum_numbers:

with pytest.raises(ValueError):
    zeeman_shifts(0.2, 1, 0)
with pytest.raises(ValueError):
    zeeman_shifts(-1, 1, 0)
---------------------------------------------------------------------------
ZeroDivisionError                         Traceback (most recent call last)
Cell In[18], line 2
      1 with pytest.raises(ValueError):
----> 2     zeeman_shifts(-1, 1, 0)

Cell In[8], line 2, in zeeman_shifts(j, l, s)
      1 def zeeman_shifts(j, l, s):
----> 2     lande = 1 + (j * (j + 1) - l * (l + 1) + s * (s + 1)) / (2 * j * (j + 1))
      3     magneton = physical_constants["Bohr magneton in eV/T"][0]
      4     magn = magnetic_quantum_numbers(j)

ZeroDivisionError: division by zero

We see that zeeman_shifts inherited the check for values of \(j\) that are not integers or half-integers. But passing in \(j=-1\) triggers a different error which tells us something about the implementation of our function but is not very informative to a user. (Try out \(j=-2\), too!) We can remedy the situation by adding a corresponding check to the beginning of zeeman_shifts. Feel free to do this as an exercise.
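pytest.raises(ValueError) passes only if a ValueError, or a subclass of it, is raised inside the block. Any other exception, such as the ZeroDivisionError above, propagates and fails the test.

To be even more specific, pytest.raises accepts a match argument, a regular expression that is searched in the exception message. The context manager also gives access to the caught exception.

Here is a sketch with a hypothetical helper, check_total_angular_momentum, standing in for the check you would add to zeeman_shifts:

```python
import pytest


def check_total_angular_momentum(j):
    """Hypothetical stand-in for a check at the start of zeeman_shifts."""
    if j < 0:
        raise ValueError(f"j must be greater than or equal to 0, got {j}")
    return j


# Passes only if a ValueError is raised *and* its message matches the regex.
with pytest.raises(ValueError, match="greater than or equal to 0"):
    check_total_angular_momentum(-1)

# The ExceptionInfo object gives access to the exception after the block.
with pytest.raises(ValueError) as excinfo:
    check_total_angular_momentum(-2)
print(excinfo.value)  # j must be greater than or equal to 0, got -2
```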

But there are more potential errors. As already mentioned, there are physical restrictions on the allowed function arguments: \(j = l+s, l+s-1, \ldots , |l-s|\). The source code of zeeman_shifts shows that these are not checked. As an exercise, you can implement such a check in zeeman_shifts and add a corresponding test here:

Solution:

def zeeman_shifts(j, l, s):
    if l + s < j or j < np.abs(l - s):
        raise ValueError(f"Angular momenta out of range: j={j}, l={l}, s={s}")
    if l != int(l):
        raise ValueError(f"Orbital angular momentum is not an integer: {l}")
    if s == int(s) and j != int(j):
        raise ValueError(
            f"Spin is integer but total angular momentum is not: s={s}, j={j}"
        )
    if s != int(s) and 2 * s == int(2 * s) and j == int(j):
        raise ValueError(
            f"Spin is half integer but total angular momentum is integer: s={s}, j={j}"
        )
    lande = 1 + (j * (j + 1) - l * (l + 1) + s * (s + 1)) / (2 * j * (j + 1))
    magneton = physical_constants["Bohr magneton in eV/T"][0]
    magn = magnetic_quantum_numbers(j)
    return [magneton * m * lande for m in magn]


# out of range
with pytest.raises(ValueError):
    zeeman_shifts(-1, 0, 0)
with pytest.raises(ValueError):
    zeeman_shifts(-1, 4, 2)
with pytest.raises(ValueError):
    zeeman_shifts(1, 0, 0)
with pytest.raises(ValueError):
    zeeman_shifts(3, 1, 0.5)
# orbital integer check
with pytest.raises(ValueError):
    zeeman_shifts(0.5, 0.5, 0)
# j is incorrect category
with pytest.raises(ValueError):
    zeeman_shifts(1, 1, 0.5)
with pytest.raises(ValueError):
    zeeman_shifts(1.5, 1, 1)

Exercises

You can do the exercises in this section in any order you like.

As mentioned at the beginning, tests are usually written in separate modules rather than in notebooks. All exercises are available in this notebook as well as in test_exercises.py. The latter shows how you can organise tests in a module using pytest. Note, however, that the descriptions in this notebook are more thorough and contain nicely formatted equations.
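Here is a rough sketch of such a module; the file name test_magnetic.py is hypothetical. By default, pytest collects functions whose names start with test from files named test_*.py or *_test.py and runs each one as a separate test, for example via `pytest test_magnetic.py`. pytest.mark.parametrize runs one test function for several inputs.

```python
# test_magnetic.py (hypothetical file name)
import numpy as np
import pytest


def magnetic_quantum_numbers(j):
    # In a real project this would be imported from the module under test.
    if j < 0:
        raise ValueError("j must be greater than or equal to 0")
    if (2 * j) != int(2 * j):
        raise ValueError("j must be an integer or half-integer")
    return list(np.arange(-j, j + 1, 1))


def test_integer_j():
    assert magnetic_quantum_numbers(0) == [0]
    assert magnetic_quantum_numbers(1) == [-1, 0, 1]


def test_half_integer_j():
    assert magnetic_quantum_numbers(0.5) == [-0.5, 0.5]


@pytest.mark.parametrize("j", [-1, 0.2])
def test_invalid_j_raises(j):
    with pytest.raises(ValueError):
        magnetic_quantum_numbers(j)
```

Each failing test is reported separately, so one broken case does not hide the others.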

If you stick with the notebook, you might want to use an alternative to the assert keyword that provides more information on failure. You can just replace

assert a == b

with

np.testing.assert_equal(a, b)
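To see the difference, here is what np.testing.assert_equal reports for two lists that differ in one element (the values are arbitrary):

```python
import numpy as np

try:
    np.testing.assert_equal([1, 2, 3], [1, 2, 4])
except AssertionError as err:
    report = str(err)

# Unlike a bare assert, the message names the failing item and shows both values.
print(report)
```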

Hint:

It is generally a good idea to perform ‘black box testing’, that is, writing tests without modifying or, where possible, even looking at the code that we test. The examples in this notebook contain an implementation of the functions that we want to test. Avoid reading those if possible and focus on the descriptions given in text or function docstrings.

Fibonacci

The Fibonacci sequence is a common example of a recursive function. Here, we indeed have a recursive implementation that we want to test. This can, of course, be done with a number of examples that you pick manually.

However, there is a different approach. We can also compute the Fibonacci numbers using

\[ x_n = \frac{\varphi^n - {(1 - \varphi)}^n}{\sqrt{5}} \]

and compare that with the recursive implementation. Here, \(\varphi \approx 1.61803399\) is the golden ratio. You can get a high-precision approximation of it from SciPy:

from scipy.constants import golden

Make sure to round the result of the above equation to the nearest integer, or use an approximate comparison, when comparing the two implementations.

def fibonacci(n):
    """
    Return the nth Fibonacci number.
    """
    if n in (0, 1):
        return n
    return fibonacci(n - 1) + fibonacci(n - 2)

Solution:

def fibonacci(n):
    """
    Return the nth Fibonacci number.
    """
    if n < 0:
        raise ValueError("Input must be a non-negative integer.")
    if n in (0, 1):
        return n
    return fibonacci(n - 1) + fibonacci(n - 2)


from scipy.constants import golden


def fib_golden(n):
    # round() guards against a result like 4.9999999 being truncated to 4 by int()
    return int(round((golden**n - (1 - golden) ** n) / np.sqrt(5)))


for i in range(20):
    assert fibonacci(i) == fib_golden(i)

with pytest.raises(ValueError):
    fibonacci(-1)

Leap Years

A leap year is any year that is divisible by 4 unless it is also divisible by 100 and not by 400. The author of the function tried to be clever, but is it actually correct?

def is_leapyear(year):
    """
    Return True if year is a leap year, False otherwise.
    """
    return year % 400 == year % 100 + year % 4 != 0

Solution:

assert not is_leapyear(1)
assert not is_leapyear(2)
assert not is_leapyear(3)
assert is_leapyear(4)
assert not is_leapyear(100)
assert is_leapyear(400)
assert is_leapyear(2020)
assert is_leapyear(2000)
assert not is_leapyear(1000)
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[23], line 6
      4 assert is_leapyear(4)
      5 assert not is_leapyear(100)
----> 6 assert is_leapyear(400)
      7 assert is_leapyear(2020)
      8 assert is_leapyear(2000)

AssertionError: 
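Hand-picked years only cover the cases we think of. A complementary check compares the function with a trusted reference over a whole range of years. Here we assume the standard library's calendar.isleap as that reference, in the same spirit as comparing the two Fibonacci implementations. The sketch repeats the function from above so that it runs on its own. It reveals a class of failures that none of the hand-picked years above covers.

```python
import calendar


def is_leapyear(year):
    """Return True if year is a leap year, False otherwise. (Copy of the version above.)"""
    return year % 400 == year % 100 + year % 4 != 0


# Collect every year in one full 400-year cycle on which the two disagree.
mismatches = [
    year for year in range(1, 401) if is_leapyear(year) != calendar.isleap(year)
]
print(mismatches[:5])  # [104, 108, 112, 116, 120]
print(len(mismatches))
```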

Histogramming

Here, we have a simple function for building histograms. When testing it, remember to think of possible edge cases. Also, are all conditions on the arguments validated? And of course, you can always compute the results for some small cases manually.

def histogram(data, bin_edges):
    """
    Construct a histogram of one dimensional data.

    Arguments:
      data: 1D array.
      bin_edges: (len=nbins+1) 1D array of bin edges.
                 Must be monotonically increasing.
                 Bin i is in range [bin_edges[i], bin_edges[i+1]).
    """
    hist = np.zeros(shape=[len(bin_edges) - 1], dtype=int)
    for value in data:
        for i in range(1, len(bin_edges)):
            if value < bin_edges[i]:
                hist[i - 1] += 1
                break
    return hist

Solution:

edges = np.array([0, 1, 2, 3])
# empty input
np.testing.assert_array_equal(histogram(np.array([]), edges), [0, 0, 0])
# input outside of bins
np.testing.assert_array_equal(histogram(np.array([4, 3]), edges), [0, 0, 0])
# manual examples
np.testing.assert_array_equal(
    histogram(np.array([0, 0.1, 0.6, 2, 2.2, 0.2]), edges), [4, 0, 2]
)
np.testing.assert_array_equal(
    histogram(np.array([-0.5, 0.1, 1.2, 2.2, 2.2, 0.3]), np.array([-1, 1, 2, 3])),
    [3, 1, 2],
)
np.testing.assert_array_equal(
    histogram(np.array([-0.5, 0.1, 1.2, 2.2, 2.2, 0.3]), edges), [2, 1, 2]
)
# not monotonically increasing bins
with pytest.raises(ValueError):
    histogram(np.array([0, 0.1, 0.6, 2, 2.2, 0.2]), np.array([0, 1, 3, 2]))
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[25], line 14
      7 np.testing.assert_array_equal(
      8     histogram(np.array([0, 0.1, 0.6, 2, 2.2, 0.2]), edges), [4, 0, 2]
      9 )
     10 np.testing.assert_array_equal(
     11     histogram(np.array([-0.5, 0.1, 1.2, 2.2, 2.2, 0.3]), np.array([-1, 1, 2, 3])),
     12     [3, 1, 2],
     13 )
---> 14 np.testing.assert_array_equal(
     15     histogram(np.array([-0.5, 0.1, 1.2, 2.2, 2.2, 0.3]), edges), [2, 1, 2]
     16 )
     17 # not monotonically increasing bins
     18 with pytest.raises(ValueError):

    [... skipping hidden 1 frame]

File /opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)

File /opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/numpy/testing/_private/utils.py:797, in assert_array_compare(comparison, x, y, err_msg, verbose, header, precision, equal_nan, equal_inf, strict)
    793         err_msg += '\n' + '\n'.join(remarks)
    794         msg = build_err_msg([ox, oy], err_msg,
    795                             verbose=verbose, header=header,
    796                             names=('x', 'y'), precision=precision)
--> 797         raise AssertionError(msg)
    798 except ValueError:
    799     import traceback

AssertionError: 
Arrays are not equal

Mismatched elements: 1 / 3 (33.3%)
Max absolute difference: 1
Max relative difference: 0.5
 x: array([3, 1, 2])
 y: array([2, 1, 2])
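Besides hand-computed examples, a property that must hold for any input makes a useful test. Every value in [bin_edges[0], bin_edges[-1]) should fall into exactly one bin, so the counts must add up to the number of such values.

The sketch below checks this with random data, assuming half-open bins as the solution above does. The function is repeated so that the block runs on its own, and the helper total_count_matches is a hypothetical name.

```python
import numpy as np


def histogram(data, bin_edges):
    # Copy of the function above.
    hist = np.zeros(shape=[len(bin_edges) - 1], dtype=int)
    for value in data:
        for i in range(1, len(bin_edges)):
            if value < bin_edges[i]:
                hist[i - 1] += 1
                break
    return hist


def total_count_matches(data, bin_edges):
    """Property: the bin counts add up to the number of values inside the bin range."""
    inside = np.count_nonzero((data >= bin_edges[0]) & (data < bin_edges[-1]))
    return histogram(data, bin_edges).sum() == inside


rng = np.random.default_rng(seed=1)
edges = np.array([0.0, 1.0, 2.0, 3.0])

# Data entirely inside the bin range: the property holds.
print(total_count_matches(rng.uniform(0, 3, size=1000), edges))  # True
# Data that also lies below the first edge: the property is violated.
print(total_count_matches(rng.uniform(-1, 4, size=1000), edges))  # False
```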

Chi Squared

The goodness of a fit is often measured using \(\chi^2\). It is defined as the sum of the squared differences between model and measurement, each weighted by the inverse variance of the measurement:

\[ \chi^2 = \sum_{i=1}^{N}\, \frac{{(y_{\text{model},\,i} - y_{\text{meas},\,i})}^2}{\sigma^{2}_{i}} \]

Things to consider: Does the function reproduce small examples? Are there specific inputs that make for very simple outputs? Are there any symmetries that you can test? Are there invalid arguments?

def chi_squared(model, meas, errors):
    return np.sum(((model - meas) / errors) ** 2)

Solution:

model = np.array([1, 4, 2, 3])
meas = np.array([2, 4, 4, 2])
errors = np.array([1, 2, 2, 1])
# special cases
np.testing.assert_almost_equal(chi_squared(model, model, errors), 0)
np.testing.assert_almost_equal(
    chi_squared(model, np.zeros(len(model)), np.ones(len(model))), np.sum(model**2)
)
# symmetry in model and meas
np.testing.assert_almost_equal(
    chi_squared(model, meas, errors), chi_squared(meas, model, errors)
)
# manually computed result
np.testing.assert_almost_equal(chi_squared(model, meas, errors), 3)

# length mismatches
with pytest.raises(ValueError):
    chi_squared(model, [1, 2], errors)
with pytest.raises(ValueError):
    chi_squared([1, 2, 3], meas, errors)
with pytest.raises(ValueError):
    chi_squared(model, meas, [1, 2, 3, 4, 5])
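Two more symmetries follow directly from the formula:

  • shifting model and measurement by the same constant leaves every difference unchanged;
  • scaling model, measurement, and uncertainties by the same factor leaves every term unchanged.

Here is a sketch with random inputs. The seed, array size, factors, and tolerance are arbitrary choices.

```python
import numpy as np


def chi_squared(model, meas, errors):
    # Copy of the function above.
    return np.sum(((model - meas) / errors) ** 2)


rng = np.random.default_rng(seed=0)
model = rng.normal(size=20)
meas = rng.normal(size=20)
errors = rng.uniform(0.5, 2.0, size=20)
reference = chi_squared(model, meas, errors)

# A common factor on model, measurement and uncertainties cancels in every term.
for factor in (1e-3, 2.0, 1e3):
    np.testing.assert_allclose(
        chi_squared(factor * model, factor * meas, factor * errors), reference, rtol=1e-10
    )

# A common offset of model and measurement cancels in every difference.
np.testing.assert_allclose(chi_squared(model + 5.0, meas + 5.0, errors), reference, rtol=1e-10)
print("Symmetry checks passed")
```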

Least Squares Fit

Note:

This is a potentially longer and more complicated exercise; feel free to skip it.

In this exercise, we look at fitting a parabola using the method of least squares. There is a fundamental difference from the other exercises: here, we generally cannot sensibly construct an expected output because the fit has to deal with randomness. We can still test a number of special cases, though. Additionally, in cases like this one, it can make sense to do some (at least partially) manual tests. For example, we can perform a fit and plot the result to see if it makes sense.

def fit_parabola(x, y_meas, errors):
    """
    Perform a least squares fit of a parabola.

    Arguments:
      x: Independent variable.
      y_meas: measured values of the dependent variable.
      errors: Uncertainties of the measured values.

    Returns:
      Best fit results for the parameters of the parabola.
    """
    X = np.array([np.ones_like(x), x, x**2]).T
    V = np.diag(1 / errors**2)
    return np.linalg.inv(X.T @ V @ X) @ X.T @ V @ y_meas

Solution:

import matplotlib.pyplot as plt


def eval_parabola(x, params):
    return params[0] + params[1] * x + params[2] * x**2


# clean parabola
x = np.linspace(-2, 4, 10)
true_params = np.array([0.1, 1.2, -2.4])
y_meas = eval_parabola(x, true_params)
errors = np.full(len(y_meas), 0.01)
best_fit = fit_parabola(x, y_meas, errors)
np.testing.assert_almost_equal(true_params, best_fit)

# mismatch in argument sizes
with pytest.raises(ValueError):
    fit_parabola(x, y_meas, np.array([1, 2]))
with pytest.raises(ValueError):
    fit_parabola(x, np.zeros(len(x) + 1), errors)
with pytest.raises(ValueError):
    fit_parabola(np.zeros(len(x) + 1), y_meas, errors)

# noisy data
x = np.linspace(-2, 4, 10)
true_params = np.array([-0.4, 2.2, 3.14])
errors = np.random.uniform(0.5, 3, len(x))
y_meas = np.random.normal(eval_parabola(x, true_params), errors)
best_fit = fit_parabola(x, y_meas, errors)
# does not work!
# np.testing.assert_almost_equal(true_params, best_fit)

# plot the result to eyeball it
plt.errorbar(x, y_meas, errors, ls="", marker=".")
plt.plot(x, eval_parabola(x, best_fit))
(Plot: the noisy data with error bars and the best-fit parabola.)
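Although the noisy fit cannot be compared to the true parameters exactly, it can be compared statistically. For weighted least squares, the covariance of the fitted parameters is \((X^T V X)^{-1}\), with X and V as in fit_parabola.

If the uncertainties are correctly specified, each "pull" \((\text{fit} - \text{truth}) / \sigma\) should roughly follow a standard normal distribution. The sketch below uses a fixed seed and a generous bound of 5σ; both are arbitrary choices. The helper parameter_uncertainties is hypothetical.

```python
import numpy as np


def fit_parabola(x, y_meas, errors):
    # Copy of the function above.
    X = np.array([np.ones_like(x), x, x**2]).T
    V = np.diag(1 / errors**2)
    return np.linalg.inv(X.T @ V @ X) @ X.T @ V @ y_meas


def parameter_uncertainties(x, errors):
    # Standard errors: square roots of the diagonal of (X^T V X)^-1.
    X = np.array([np.ones_like(x), x, x**2]).T
    V = np.diag(1 / errors**2)
    return np.sqrt(np.diag(np.linalg.inv(X.T @ V @ X)))


rng = np.random.default_rng(seed=42)  # fixed seed makes the test reproducible
x = np.linspace(-2, 4, 10)
true_params = np.array([-0.4, 2.2, 3.14])
errors = rng.uniform(0.5, 3, len(x))
y_true = true_params[0] + true_params[1] * x + true_params[2] * x**2
y_meas = rng.normal(y_true, errors)

best_fit = fit_parabola(x, y_meas, errors)
pulls = (best_fit - true_params) / parameter_uncertainties(x, errors)
print(np.round(pulls, 2))
assert np.all(np.abs(pulls) < 5)
```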

Energy Transfer

The energy transfer in inelastic scattering of neutrons is

  • Direct: \[ \Delta E_{\text{direct}} = E_i - \frac{m_{\text{n}} L_2^2}{2 {(t - t_0)}^2}\, , \quad t_0 = \sqrt{\frac{L_1^2 m_\text{n}}{E_i}} \]

  • Indirect: \[ \Delta E_{\text{indirect}} = \frac{m_{\text{n}} L_1^2}{2 {(t - t_0)}^2} - E_f\, , \quad t_0 = \sqrt{\frac{L_2^2 m_\text{n}}{E_f}} \]

where \(E_i\), \(E_f\) are the initial and final energy, respectively. \(t\) is the time-of-flight, \(m_{\text{n}}\) is the neutron mass, and \(L_1\) and \(L_2\) are the lengths of the primary and secondary flight paths.

As usual, try to find special cases where the equations become simpler and you can easily predict the output. Also, write tests with incorrect / nonsense input to check if the function reports errors properly. Mind the units given in the docstring. It is important to pick inputs with a good order of magnitude. Otherwise, we will have problems with floating point precision. You might need to increase the required precision of the assertions. This can be done like so:

np.testing.assert_almost_equal(a, b, decimal=20)

See the NumPy documentation for details: https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_almost_equal.html#numpy.testing.assert_almost_equal.

from scipy.constants import neutron_mass
def energy_transfer(ei_or_ef, tof, L1, L2, mode):
    """
    Compute the energy transfer for inelastic neutron scattering.

    Units in square brackets.

    Arguments:
      ei_or_ef: [J] In direct scattering: the initial energy.
                    In indirect scattering: the final energy.
      tof: [s] Time-of-flight.
      L1: [m] Primary flight path.
      L2: [m] Secondary flight path.
      mode: Either 'direct' or 'indirect'.
    """
    if mode == "direct":
        t0 = np.sqrt(L1**2 * neutron_mass / ei_or_ef)
    elif mode == "indirect":
        t0 = np.sqrt(L2**2 * neutron_mass / ei_or_ef)
    delta_t = tof - t0
    return ei_or_ef - neutron_mass * L2**2 / 2 / delta_t**2

Solution:

from scipy.constants import physical_constants

ei = 1000 * physical_constants["electron volt"][0]  # J
tof = 1e-6  # s
mn = neutron_mass  # kg
np.testing.assert_almost_equal(energy_transfer(ei, tof, 3.2, 0, "direct"), ei)
assert energy_transfer(ei, tof, 1, 0.5, "direct") < ei
np.testing.assert_almost_equal(
    energy_transfer(ei, 1, 0, 1, "direct"), ei - neutron_mass / 2
)

ef = 23 * physical_constants["electron volt"][0]  # J
np.testing.assert_almost_equal(energy_transfer(ef, tof, 0, 0.5, "indirect"), ef)
np.testing.assert_almost_equal(
    energy_transfer(ef, 1, 1, 0, "indirect"), neutron_mass / 2 - ef, decimal=20
)

with pytest.raises(ValueError):
    energy_transfer(ei, tof, 2, 1, "invalid")
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[32], line 14
     12 ef = 23 * physical_constants["electron volt"][0]  # J
     13 np.testing.assert_almost_equal(energy_transfer(ef, tof, 0, 0.5, "indirect"), ef)
---> 14 np.testing.assert_almost_equal(
     15     energy_transfer(ef, 1, 1, 0, "indirect"), neutron_mass / 2 - ef, decimal=20
     16 )
     18 with pytest.raises(ValueError):
     19     energy_transfer(ei, tof, 2, 1, "invalid")

File /opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)

File /opt/hostedtoolcache/Python/3.10.14/x64/lib/python3.10/site-packages/numpy/testing/_private/utils.py:537, in assert_almost_equal(actual, desired, decimal, err_msg, verbose)
    535     pass
    536 if abs(desired - actual) >= np.float64(1.5 * 10.0**(-decimal)):
--> 537     raise AssertionError(_build_err_msg())

AssertionError: 
Arrays are not almost equal to 20 decimals
 ACTUAL: 3.6850062582e-18
 DESIRED: -3.685006257362536e-18