diff --git a/ratapi/utils/plotting.py b/ratapi/utils/plotting.py index 6f9debb7..98221f40 100644 --- a/ratapi/utils/plotting.py +++ b/ratapi/utils/plotting.py @@ -763,7 +763,7 @@ def plot_one_hist( results: ratapi.outputs.BayesResults, param: int | str, smooth: bool = True, - sigma: float | None = None, + window_size: int = 8, estimated_density: Literal["normal", "lognor", "kernel", None] = None, axes: Axes | None = None, block: bool = False, @@ -783,9 +783,9 @@ def plot_one_hist( smooth : bool, default True Whether to apply Gaussian smoothing to the histogram. Defaults to True. - sigma: float or None, default None - If given, is used as the sigma-parameter for the Gaussian smoothing. - If None, the default (1/3rd of parameter chain standard deviation) is used. + window_size : int, default 8 + The width of the smoothing window centered around the element being averaged. + The window moves down the length of the data, computing an average over the elements within each window. estimated_density : 'normal', 'lognor', 'kernel' or None, default None If None (default), ignore. Else, add an estimated density of the given form on top of the histogram by the following estimations: @@ -826,9 +826,7 @@ def plot_one_hist( sd_y = np.std(parameter_chain) if smooth: - if sigma is None: - sigma = sd_y / 2 - counts = gaussian_filter1d(counts, sigma) + counts = moving_average(counts, window_size=window_size) axes.hist( bins[:-1], bins, @@ -1233,3 +1231,30 @@ def plot_bayes(project: ratapi.Project, results: ratapi.outputs.BayesResults): plot_corner(results) else: raise ValueError("Bayes plots are only available for the results of Bayesian analysis (NS or DREAM)") + + +def moving_average(data: np.ndarray, window_size: int = 8) -> list[float]: + """Calculate the moving average of an array with a given window size. + + This is a python equivalent to MATLABs smoothdata(A, 'movmean') + + Parameters + ---------- + data : np.ndarray + The input array to smooth + window_size : int + The window slides down the length of the vector, + computing an average over the elements within each window. + + """ + assert 0 <= window_size <= len(data) + moving_averages = [] + + for i in range(len(data)): + start_window_ind = floor(float(i - window_size / 2)) if i - window_size / 2 > 0 else 0 + end_window_ind = floor(float(i + window_size / 2)) if i + window_size / 2 < len(data) else len(data) + window_average = np.sum(data[start_window_ind:end_window_ind]) / (end_window_ind + 0 - start_window_ind) + moving_averages.append(window_average) + i += 1 + + return moving_averages diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 31049a28..f1721ec1 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -495,3 +495,61 @@ def test_extract_plot_data(data) -> None: with pytest.raises(ValueError, match=r"Parameter `shift_value` must be between 0 and 100"): RATplot._extract_plot_data(data, False, True, 100.5) + + +def test_moving_average() -> None: + """Test the moving_average function.""" + data_to_average = np.arange(0, 20) + mov_avg = RATplot.moving_average(data_to_average) + assert mov_avg == [ + 1.5, + 2.0, + 2.5, + 3.0, + 3.5, + 4.5, + 5.5, + 6.5, + 7.5, + 8.5, + 9.5, + 10.5, + 11.5, + 12.5, + 13.5, + 14.5, + 15.5, + 16.0, + 16.5, + 17.0, + ] + + mov_avg = RATplot.moving_average(data_to_average, window_size=2) + assert mov_avg == [ + 0.0, + 0.5, + 1.5, + 2.5, + 3.5, + 4.5, + 5.5, + 6.5, + 7.5, + 8.5, + 9.5, + 10.5, + 11.5, + 12.5, + 13.5, + 14.5, + 15.5, + 16.5, + 17.5, + 18.5, + ] + + with pytest.raises(AssertionError): + RATplot.moving_average(data_to_average, window_size=-1) + + with pytest.raises(AssertionError): + RATplot.moving_average(data_to_average, window_size=len(data_to_average) + 1)