Skip to content

Commit 5c3ebdf

Browse files
authored
Merge pull request #1892 from jeromekelleher/better-error-general-stat
Better error message in general stat
2 parents 510e515 + 6caea97 commit 5c3ebdf

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

‎python/_tskitmodule.c‎

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8192,13 +8192,23 @@ general_stat_func(tsk_size_t K, const double *X, tsk_size_t M, double *Y, void *
81928192
goto out;
81938193
}
81948194
Y_array = (PyArrayObject *) PyArray_FromAny(
8195-
result, PyArray_DescrFromType(NPY_FLOAT64), 1, 1, NPY_ARRAY_IN_ARRAY, NULL);
8195+
result, PyArray_DescrFromType(NPY_FLOAT64), 0, 0, NPY_ARRAY_IN_ARRAY, NULL);
81968196
if (Y_array == NULL) {
81978197
goto out;
81988198
}
8199+
if (PyArray_NDIM(Y_array) != 1) {
8200+
PyErr_Format(PyExc_ValueError,
8201+
"Array returned by general_stat callback is %d dimensional; "
8202+
"must be 1D",
8203+
(int) PyArray_NDIM(Y_array));
8204+
goto out;
8205+
}
81998206
Y_dims = PyArray_DIMS(Y_array);
82008207
if (Y_dims[0] != (npy_intp) M) {
8201-
PyErr_SetString(PyExc_ValueError, "Incorrect callback output dimensions");
8208+
PyErr_Format(PyExc_ValueError,
8209+
"Array returned by general_stat callback is of length %d; "
8210+
"must be %d",
8211+
Y_dims[0], M);
82028212
goto out;
82038213
}
82048214
/* Copy the contents of the return Y array into Y */

‎python/tests/test_tree_stats.py‎

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6150,3 +6150,38 @@ def test_uncalibrated_time_general_stat(self, ts_fixture):
61506150
W.shape[1],
61516151
mode="branch",
61526152
)
6153+
6154+
6155+
class TestGeneralStatCallbackErrors:
6156+
def test_zero_d(self, ts_fixture):
6157+
def f_0d(_):
6158+
return 0
6159+
6160+
msg = "Array returned by general_stat callback is 0 dimensional; must be 1D"
6161+
with pytest.raises(ValueError, match=msg):
6162+
ts_fixture.sample_count_stat(
6163+
sample_sets=[ts_fixture.samples()], f=f_0d, output_dim=1, strict=False
6164+
)
6165+
6166+
def test_two_d(self, ts_fixture):
6167+
def f_2d(x):
6168+
return np.array([x])
6169+
6170+
msg = "Array returned by general_stat callback is 2 dimensional; must be 1D"
6171+
with pytest.raises(ValueError, match=msg):
6172+
ts_fixture.sample_count_stat(
6173+
sample_sets=[ts_fixture.samples()], f=f_2d, output_dim=1, strict=False
6174+
)
6175+
6176+
def test_wrong_length(self, ts_fixture):
6177+
def f_too_long(_):
6178+
return np.array([0, 0])
6179+
6180+
msg = "Array returned by general_stat callback is of length 2; must be 1"
6181+
with pytest.raises(ValueError, match=msg):
6182+
ts_fixture.sample_count_stat(
6183+
sample_sets=[ts_fixture.samples()],
6184+
f=f_too_long,
6185+
output_dim=1,
6186+
strict=False,
6187+
)

0 commit comments

Comments
 (0)