From c71237dec770a217cfb72ddac2616ce2e6cba9a5 Mon Sep 17 00:00:00 2001 From: Paul Sharp <44529197+DrPaulSharp@users.noreply.github.com> Date: Mon, 21 Jul 2025 10:21:07 +0100 Subject: [PATCH 1/8] Updates github action for ruff (#173) --- .github/workflows/run_ruff.yml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.github/workflows/run_ruff.yml b/.github/workflows/run_ruff.yml index 5fc5b79f..64637b07 100644 --- a/.github/workflows/run_ruff.yml +++ b/.github/workflows/run_ruff.yml @@ -7,8 +7,6 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: chartboost/ruff-action@v1 - - uses: chartboost/ruff-action@v1 - with: - args: 'format --check' + - uses: astral-sh/ruff-action@v3 + - run: ruff format --check \ No newline at end of file From 45e44ef7831b6513b159e7b11c7bf0168fcc9490 Mon Sep 17 00:00:00 2001 From: StephenNneji <34302892+StephenNneji@users.noreply.github.com> Date: Wed, 23 Jul 2025 16:04:30 +0100 Subject: [PATCH 2/8] Updates cpp and add Mac ARM to test (#172) --- .github/workflows/run_tests.yml | 18 ++++++++++++++---- cpp/RAT | 2 +- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index e930df1a..a2de2b66 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -19,13 +19,13 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, windows-latest, macos-13] + platform: [windows-latest, ubuntu-latest, macos-13, macos-14] version: ["3.10", "3.13"] defaults: run: shell: bash -l {0} - runs-on: ${{ matrix.os }} + runs-on: ${{ matrix.platform}} steps: - uses: actions/checkout@v4 @@ -35,8 +35,8 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.version }} - - name: Install OMP (MacOS) - if: runner.os == 'macOS' + - name: Install OMP (MacOS Intel) + if: matrix.platform == 'macos-13' run: | brew install llvm libomp echo "export CC=/usr/local/opt/llvm/bin/clang" >> ~/.bashrc @@ -45,6 +45,16 @@ jobs: echo "export CXXFLAGS=\"$CXXFLAGS -I/usr/local/opt/libomp/include\"" >> ~/.bashrc echo "export LDFLAGS=\"$LDFLAGS -Wl,-rpath,/usr/local/opt/libomp/lib -L/usr/local/opt/libomp/lib -lomp\"" >> ~/.bashrc source ~/.bashrc + - name: Install OMP (MacOS M1) + if: matrix.platform == 'macos-14' + run: | + brew install llvm libomp + echo "export CC=/opt/homebrew/opt/llvm/bin/clang" >> ~/.bashrc + echo "export CXX=/opt/homebrew/opt/llvm/bin/clang++" >> ~/.bashrc + echo "export CFLAGS=\"$CFLAGS -I/opt/homebrew/opt/libomp/include\"" >> ~/.bashrc + echo "export CXXFLAGS=\"$CXXFLAGS -I/opt/homebrew/opt/libomp/include\"" >> ~/.bashrc + echo "export LDFLAGS=\"$LDFLAGS -Wl,-rpath,/opt/homebrew/opt/libomp/lib -L/opt/homebrew/opt/libomp/lib -lomp\"" >> ~/.bashrc + source ~/.bashrc - name: Install OMP (Linux) if: runner.os == 'Linux' run: | diff --git a/cpp/RAT b/cpp/RAT index cf81c8d0..370d46f1 160000 --- a/cpp/RAT +++ b/cpp/RAT @@ -1 +1 @@ -Subproject commit cf81c8d00f0d0348cbeb360446362f8381093203 +Subproject commit 370d46f19859eb1ada7b031e6540c7747e328f82 From 50289e3c2a802b813612a4fe1544a5204ec66715 Mon Sep 17 00:00:00 2001 From: StephenNneji <34302892+StephenNneji@users.noreply.github.com> Date: Tue, 29 Jul 2025 10:02:50 +0100 Subject: [PATCH 3/8] Adds the `shift_value` argument to `plot_ref_sld` (#174) --- ratapi/utils/plotting.py | 206 +++------------------------------------ tests/test_plotting.py | 18 +--- 2 files changed, 20 insertions(+), 204 deletions(-) diff --git a/ratapi/utils/plotting.py b/ratapi/utils/plotting.py index 38d3d777..d4cfa599 100644 --- a/ratapi/utils/plotting.py +++ b/ratapi/utils/plotting.py @@ -21,7 +21,7 @@ from ratapi.rat_core import PlotEventData, makeSLDProfile -def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool): +def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool, shift_value: float): """Extract the plot data for the sld, ref, error plot lines. Parameters @@ -33,6 +33,8 @@ def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool Controls whether Q^4 is plotted on the reflectivity plot show_error_bar : bool, default: True Controls whether the error bars are shown + shift_value : float + A value between 1 and 100 that controls the spacing between the reflectivity plots for each of the contrasts Returns ------- @@ -42,9 +44,12 @@ def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool """ results = {"ref": [], "error": [], "sld": [], "sld_resample": []} + if shift_value < 1 or shift_value > 100: + raise ValueError("Parameter `shift_value` must be between 1 and 100") + for i, (r, data, sld) in enumerate(zip(event_data.reflectivity, event_data.shiftedData, event_data.sldProfiles)): # Calculate the divisor - div = 1 if i == 0 and not q4 else 2 ** (4 * (i + 1)) + div = 1 if i == 0 and not q4 else 10 ** ((i / 100) * shift_value) q4_data = 1 if not q4 or not event_data.dataPresent[i] else data[:, 0] ** 4 mult = q4_data / div @@ -87,194 +92,6 @@ def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool return results -class PlotSLDWithBlitting: - """Create a SLD plot that uses blitting to get faster draws. - - The blit plot stores the background from an - initial draw then updates the foreground (lines and error bars) if the background is not changed. - - Parameters - ---------- - data : PlotEventData - The plot event data that contains all the information - to generate the ref and sld plots - fig : matplotlib.pyplot.figure, optional - The figure class that has two subplots - linear_x : bool, default: False - Controls whether the x-axis on reflectivity plot uses the linear scale - q4 : bool, default: False - Controls whether Q^4 is plotted on the reflectivity plot - show_error_bar : bool, default: True - Controls whether the error bars are shown - show_grid : bool, default: False - Controls whether the grid is shown - show_legend : bool, default: True - Controls whether the legend is shown - """ - - def __init__( - self, - data: PlotEventData, - fig: Optional[matplotlib.pyplot.figure] = None, - linear_x: bool = False, - q4: bool = False, - show_error_bar: bool = True, - show_grid: bool = False, - show_legend: bool = True, - ): - self.figure = fig - self.linear_x = linear_x - self.q4 = q4 - self.show_error_bar = show_error_bar - self.show_grid = show_grid - self.show_legend = show_legend - self.updatePlot(data) - self.event_id = self.figure.canvas.mpl_connect("resize_event", self.resizeEvent) - - def __del__(self): - self.figure.canvas.mpl_disconnect(self.event_id) - - def resizeEvent(self, _event): - """Ensure the background is updated after a resize event.""" - self.__background_changed = True - - def update(self, data: PlotEventData): - """Update the foreground, if background has not changed otherwise it updates full plot. - - Parameters - ---------- - data : PlotEventData - The plot event data that contains all the information - to generate the ref and sld plots - """ - if self.__background_changed: - self.updatePlot(data) - else: - self.updateForeground(data) - - def __setattr__(self, name, value): - super().__setattr__(name, value) - if name in ["figure", "linear_x", "q4", "show_error_bar", "show_grid", "show_legend"]: - self.__background_changed = True - - def setAnimated(self, is_animated: bool): - """Set the animated property of foreground plot elements. - - Parameters - ---------- - is_animated : bool - Indicates if the animated property should been set. - """ - for line in self.figure.axes[0].lines: - line.set_animated(is_animated) - for line in self.figure.axes[1].lines: - line.set_animated(is_animated) - for container in self.figure.axes[0].containers: - container[2][0].set_animated(is_animated) - - def adjustErrorBar(self, error_bar_container, x, y, y_error): - """Adjust the error bar data. - - Parameters - ---------- - error_bar_container : Tuple - Tuple containing the artist of the errorbar i.e. (data line, cap lines, bar lines) - x : np.ndarray - The shifted data x axis data - y : np.ndarray - The shifted data y axis data - y_error : np.ndarray - The shifted data y axis error data - """ - line, _, (bars_y,) = error_bar_container - - line.set_data(x, y) - x_base = x - y_base = y - - y_error_top = y_base + y_error - y_error_bottom = y_base - y_error - - new_segments_y = [np.array([[x, yt], [x, yb]]) for x, yt, yb in zip(x_base, y_error_top, y_error_bottom)] - bars_y.set_segments(new_segments_y) - - def updatePlot(self, data: PlotEventData): - """Update the full plot. - - Parameters - ---------- - data : PlotEventData - The plot event data that contains all the information - to generate the ref and sld plots - """ - if self.figure is not None: - self.figure.clf() - self.figure = plot_ref_sld_helper( - data, - self.figure, - linear_x=self.linear_x, - q4=self.q4, - show_error_bar=self.show_error_bar, - show_grid=self.show_grid, - show_legend=self.show_legend, - animated=True, - ) - - self.bg = self.figure.canvas.copy_from_bbox(self.figure.bbox) - for line in self.figure.axes[0].lines: - self.figure.axes[0].draw_artist(line) - for line in self.figure.axes[1].lines: - self.figure.axes[1].draw_artist(line) - for container in self.figure.axes[0].containers: - self.figure.axes[0].draw_artist(container[2][0]) - self.figure.canvas.blit(self.figure.bbox) - self.setAnimated(False) - self.__background_changed = False - - def updateForeground(self, data: PlotEventData): - """Update the plot foreground only. - - Parameters - ---------- - data : PlotEventData - The plot event data that contains all the information - to generate the ref and sld plots - """ - self.setAnimated(True) - self.figure.canvas.restore_region(self.bg) - plot_data = _extract_plot_data(data, self.q4, self.show_error_bar) - - offset = 2 if self.show_error_bar else 1 - for i in range( - 0, - len(self.figure.axes[0].lines), - ): - self.figure.axes[0].lines[i].set_data(plot_data["ref"][i // offset][0], plot_data["ref"][i // offset][1]) - self.figure.axes[0].draw_artist(self.figure.axes[0].lines[i]) - - i = 0 - for j in range(len(plot_data["sld"])): - for sld in plot_data["sld"][j]: - self.figure.axes[1].lines[i].set_data(sld[0], sld[1]) - self.figure.axes[1].draw_artist(self.figure.axes[1].lines[i]) - i += 1 - - if plot_data["sld_resample"]: - for resampled in plot_data["sld_resample"][j]: - self.figure.axes[1].lines[i].set_data(resampled[0], resampled[1]) - self.figure.axes[1].draw_artist(self.figure.axes[1].lines[i]) - i += 1 - - for i, container in enumerate(self.figure.axes[0].containers): - self.adjustErrorBar(container, plot_data["error"][i][0], plot_data["error"][i][1], plot_data["error"][i][2]) - self.figure.axes[0].draw_artist(container[2][0]) - self.figure.axes[0].draw_artist(container[0]) - - self.figure.canvas.blit(self.figure.bbox) - self.figure.canvas.flush_events() - self.setAnimated(False) - - def plot_ref_sld_helper( data: PlotEventData, fig: Optional[matplotlib.pyplot.figure] = None, @@ -285,6 +102,7 @@ def plot_ref_sld_helper( show_error_bar: bool = True, show_grid: bool = False, show_legend: bool = True, + shift_value: float = 100, animated=False, ): """Clear the previous plots and updates the ref and SLD plots. @@ -311,6 +129,8 @@ def plot_ref_sld_helper( Controls whether the grid is shown show_legend : bool, default: True Controls whether the legend is shown + shift_value : float, default: 100 + A value between 1 and 100 that controls the spacing between the reflectivity plots for each of the contrasts animated : bool, default: False Controls whether the animated property of foreground plot elements should be set. @@ -339,7 +159,7 @@ def plot_ref_sld_helper( ref_plot.cla() sld_plot.cla() - plot_data = _extract_plot_data(data, q4, show_error_bar) + plot_data = _extract_plot_data(data, q4, show_error_bar, shift_value) for i, name in enumerate(data.contrastNames): ref_plot.plot(plot_data["ref"][i][0], plot_data["ref"][i][1], label=name, linewidth=1, animated=animated) color = ref_plot.get_lines()[-1].get_color() @@ -427,6 +247,7 @@ def plot_ref_sld( show_error_bar: bool = True, show_grid: bool = False, show_legend: bool = True, + shift_value: float = 100, ) -> Union[plt.Figure, None]: """Plot the reflectivity and SLD profiles. @@ -454,6 +275,8 @@ def plot_ref_sld( Controls whether the grid is shown show_legend : bool, default: True Controls whether the legend is shown + shift_value : float, default: 100 + A value between 1 and 100 that controls the spacing between the reflectivity plots for each of the contrasts Returns ------- @@ -524,6 +347,7 @@ def plot_ref_sld( show_error_bar=show_error_bar, show_grid=show_grid, show_legend=show_legend, + shift_value=shift_value, ) if return_fig: diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 3bb3f69c..c42bfeea 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -481,20 +481,12 @@ def test_bayes_validation(input_project, reflectivity_calculation_results): @pytest.mark.parametrize("data", [data(), domains_data()]) def test_extract_plot_data(data) -> None: - plot_data = RATplot._extract_plot_data(data, False, True) + plot_data = RATplot._extract_plot_data(data, False, True, 50) assert len(plot_data["ref"]) == len(data.reflectivity) assert len(plot_data["sld"]) == len(data.shiftedData) + with pytest.raises(ValueError, match=r"Parameter `shift_value` must be between 1 and 100"): + RATplot._extract_plot_data(data, False, True, 0) -@patch("ratapi.utils.plotting.plot_ref_sld_helper") -def test_blit_plot(plot_helper, fig: plt.figure) -> None: - plot_helper.return_value = fig - event_data = data() - new_plot = RATplot.PlotSLDWithBlitting(event_data) - assert plot_helper.call_count == 1 - new_plot.update(event_data) - assert plot_helper.call_count == 1 # foreground only is updated so no call to plot helper - new_plot.show_grid = False - new_plot.figure = plt.subplots(1, 2)[0] - new_plot.update(event_data) # plot properties have changed so update should call plot_helper - assert plot_helper.call_count == 2 + with pytest.raises(ValueError, match=r"Parameter `shift_value` must be between 1 and 100"): + RATplot._extract_plot_data(data, False, True, 100.5) From d6c44cefaa1faddb8691188e16eed9aece42e5dc Mon Sep 17 00:00:00 2001 From: Paul Sharp <44529197+DrPaulSharp@users.noreply.github.com> Date: Thu, 7 Aug 2025 10:02:15 +0100 Subject: [PATCH 4/8] Removes controls option "calcSLDDuringFit" (#175) * Removes controls option "calcSLDDuringFit" * Addresses review comment * Addresses further review comments --- cpp/RAT | 2 +- cpp/includes/defines.h | 3 -- cpp/rat.cpp | 36 +++++++++---------- ratapi/controls.py | 4 --- .../bayes_benchmark/bayes_benchmark.py | 2 +- .../languages/run_custom_file_languages.py | 1 - ratapi/inputs.py | 3 -- ratapi/utils/enums.py | 8 ++--- setup.py | 6 ++-- tests/test_controls.py | 23 ------------ tests/test_inputs.py | 3 -- 11 files changed, 26 insertions(+), 65 deletions(-) diff --git a/cpp/RAT b/cpp/RAT index 370d46f1..16f3ebef 160000 --- a/cpp/RAT +++ b/cpp/RAT @@ -1 +1 @@ -Subproject commit 370d46f19859eb1ada7b031e6540c7747e328f82 +Subproject commit 16f3ebef737f4c85ca601046f7a30ffabee4eb47 diff --git a/cpp/includes/defines.h b/cpp/includes/defines.h index 9e8a5ff1..097ad0eb 100644 --- a/cpp/includes/defines.h +++ b/cpp/includes/defines.h @@ -588,8 +588,6 @@ parallel : str How the calculation should be parallelised (This uses the Parallel Computing Toolbox). Can be 'single', 'contrasts' or 'points'. procedure : str Which procedure RAT should execute. Can be 'calculate', 'simplex', 'de', 'ns', or 'dream'. -calcSldDuringFit : bool - Whether SLD will be calculated during fit (for live plotting etc.) numSimulationPoints : int The number of points used for a reflectivity simulation where no data is present. resampleMinAngle : float @@ -664,7 +662,6 @@ struct Control { real_T nMCMC {}; real_T propScale {}; real_T nsTolerance {}; - boolean_T calcSldDuringFit {}; real_T numSimulationPoints {}; real_T resampleMinAngle {}; real_T resampleNPoints {}; diff --git a/cpp/rat.cpp b/cpp/rat.cpp index ddec2782..31bbe144 100644 --- a/cpp/rat.cpp +++ b/cpp/rat.cpp @@ -317,7 +317,6 @@ RAT::Controls createControlsStruct(const Control& control) control_struct.nMCMC = control.nMCMC; control_struct.propScale = control.propScale; control_struct.nsTolerance = control.nsTolerance; - control_struct.calcSldDuringFit = control.calcSldDuringFit; control_struct.numSimulationPoints = control.numSimulationPoints; control_struct.updateFreq = control.updateFreq; control_struct.updatePlotFreq = control.updatePlotFreq; @@ -333,6 +332,7 @@ RAT::Controls createControlsStruct(const Control& control) control_struct.resampleNPoints = control.resampleNPoints; stringToRatBoundedArray(control.boundHandling, control_struct.boundHandling.data, control_struct.boundHandling.size); control_struct.adaptPCR = control.adaptPCR; + control_struct.calcSLD = false; stringToRatBoundedArray(control.IPCFilePath, control_struct.IPCFilePath.data, control_struct.IPCFilePath.size); return control_struct; @@ -910,7 +910,6 @@ PYBIND11_MODULE(rat_core, m) { .def_readwrite("nMCMC", &Control::nMCMC) .def_readwrite("propScale", &Control::propScale) .def_readwrite("nsTolerance", &Control::nsTolerance) - .def_readwrite("calcSldDuringFit", &Control::calcSldDuringFit) .def_readwrite("numSimulationPoints", &Control::numSimulationPoints) .def_readwrite("resampleMinAngle", &Control::resampleMinAngle) .def_readwrite("resampleNPoints", &Control::resampleNPoints) @@ -929,12 +928,12 @@ PYBIND11_MODULE(rat_core, m) { return py::make_tuple(ctrl.parallel, ctrl.procedure, ctrl.display, ctrl.xTolerance, ctrl.funcTolerance, ctrl.maxFuncEvals, ctrl.maxIterations, ctrl.populationSize, ctrl.fWeight, ctrl.crossoverProbability, ctrl.targetValue, ctrl.numGenerations, ctrl.strategy, ctrl.nLive, ctrl.nMCMC, ctrl.propScale, - ctrl.nsTolerance, ctrl.calcSldDuringFit, ctrl.numSimulationPoints, ctrl.resampleMinAngle, ctrl.resampleNPoints, - ctrl.updateFreq, ctrl.updatePlotFreq, ctrl.nSamples, ctrl.nChains, ctrl.jumpProbability, ctrl.pUnitGamma, - ctrl.boundHandling, ctrl.adaptPCR, ctrl.IPCFilePath); + ctrl.nsTolerance, ctrl.numSimulationPoints, ctrl.resampleMinAngle, ctrl.resampleNPoints, + ctrl.updateFreq, ctrl.updatePlotFreq, ctrl.nSamples, ctrl.nChains, ctrl.jumpProbability, + ctrl.pUnitGamma, ctrl.boundHandling, ctrl.adaptPCR, ctrl.IPCFilePath); }, [](py::tuple t) { // __setstate__ - if (t.size() != 30) + if (t.size() != 29) throw std::runtime_error("Encountered invalid state unpickling ProblemDefinition object!"); /* Create a new C++ instance */ @@ -957,19 +956,18 @@ PYBIND11_MODULE(rat_core, m) { ctrl.nMCMC = t[14].cast(); ctrl.propScale = t[15].cast(); ctrl.nsTolerance = t[16].cast(); - ctrl.calcSldDuringFit = t[17].cast(); - ctrl.numSimulationPoints = t[18].cast(); - ctrl.resampleMinAngle = t[19].cast(); - ctrl.resampleNPoints = t[20].cast(); - ctrl.updateFreq = t[21].cast(); - ctrl.updatePlotFreq = t[22].cast(); - ctrl.nSamples = t[23].cast(); - ctrl.nChains = t[24].cast(); - ctrl.jumpProbability = t[25].cast(); - ctrl.pUnitGamma = t[26].cast(); - ctrl.boundHandling = t[27].cast(); - ctrl.adaptPCR = t[28].cast(); - ctrl.IPCFilePath = t[29].cast(); + ctrl.numSimulationPoints = t[17].cast(); + ctrl.resampleMinAngle = t[18].cast(); + ctrl.resampleNPoints = t[19].cast(); + ctrl.updateFreq = t[20].cast(); + ctrl.updatePlotFreq = t[21].cast(); + ctrl.nSamples = t[22].cast(); + ctrl.nChains = t[23].cast(); + ctrl.jumpProbability = t[24].cast(); + ctrl.pUnitGamma = t[25].cast(); + ctrl.boundHandling = t[26].cast(); + ctrl.adaptPCR = t[27].cast(); + ctrl.IPCFilePath = t[28].cast(); return ctrl; })); diff --git a/ratapi/controls.py b/ratapi/controls.py index 917d78eb..c5408411 100644 --- a/ratapi/controls.py +++ b/ratapi/controls.py @@ -23,7 +23,6 @@ common_fields = [ "procedure", "parallel", - "calcSldDuringFit", "numSimulationPoints", "resampleMinAngle", "resampleNPoints", @@ -58,9 +57,6 @@ class Controls(BaseModel, validate_assignment=True, extra="forbid", use_attribut parallel: Parallel = Parallel.Single """How the calculation should be parallelised. Can be 'single', 'contrasts' or 'points'.""" - calcSldDuringFit: bool = False - """Whether SLD will be calculated during fit (for live plotting etc.)""" - numSimulationPoints: int = Field(500, ge=2) """The number of points used for reflectivity simulations where no data is supplied.""" diff --git a/ratapi/examples/bayes_benchmark/bayes_benchmark.py b/ratapi/examples/bayes_benchmark/bayes_benchmark.py index 9840b002..ee6ea96e 100644 --- a/ratapi/examples/bayes_benchmark/bayes_benchmark.py +++ b/ratapi/examples/bayes_benchmark/bayes_benchmark.py @@ -250,7 +250,7 @@ def bayes_benchmark_3d(grid_size: int) -> (RAT.outputs.BayesResults, Calculation scale_param = problem.scalefactors[0] scalefactor = np.linspace(scale_param.min, scale_param.max, grid_size) - controls = RAT.Controls(procedure="calculate", calcSldDuringFit=True, display="off") + controls = RAT.Controls(procedure="calculate", display="off") def calculate_posterior(roughness_index: int, background_index: int, scalefactor_index: int) -> float: """Calculate the posterior for an item in the roughness, background, and scalefactor vectors. diff --git a/ratapi/examples/languages/run_custom_file_languages.py b/ratapi/examples/languages/run_custom_file_languages.py index 6e69f80e..2f6025ae 100644 --- a/ratapi/examples/languages/run_custom_file_languages.py +++ b/ratapi/examples/languages/run_custom_file_languages.py @@ -11,7 +11,6 @@ project = setup_problem.make_example_problem() controls = RAT.Controls() -controls.calcSldDuringFit = True # Python start = time.time() diff --git a/ratapi/inputs.py b/ratapi/inputs.py index 34b695d3..890066f9 100644 --- a/ratapi/inputs.py +++ b/ratapi/inputs.py @@ -552,7 +552,6 @@ def make_controls(input_controls: ratapi.Controls) -> Control: controls.procedure = input_controls.procedure controls.parallel = input_controls.parallel - controls.calcSldDuringFit = input_controls.calcSldDuringFit controls.numSimulationPoints = input_controls.numSimulationPoints controls.resampleMinAngle = input_controls.resampleMinAngle controls.resampleNPoints = input_controls.resampleNPoints @@ -583,8 +582,6 @@ def make_controls(input_controls: ratapi.Controls) -> Control: controls.pUnitGamma = input_controls.pUnitGamma controls.boundHandling = input_controls.boundHandling controls.adaptPCR = input_controls.adaptPCR - # IPC - controls.IPCFilePath = "" controls.IPCFilePath = input_controls._IPCFilePath diff --git a/ratapi/utils/enums.py b/ratapi/utils/enums.py index 43272355..24c50cbc 100644 --- a/ratapi/utils/enums.py +++ b/ratapi/utils/enums.py @@ -73,22 +73,22 @@ class Strategies(RATEnum): """The base vector is random.""" LocalToBest = "local to best" - """The base vector is a combination of one randomly-selected local solution + """The base vector is a combination of one randomly-selected local solution and the best solution of the previous iteration.""" BestWithJitter = "best jitter" """The base vector is the best solution of the previous iteration, with a small random perturbation applied.""" RandomWithPerVectorDither = "vector dither" - """The base vector is random, with a random scaling factor applied to each mutant. + """The base vector is random, with a random scaling factor applied to each mutant. This scaling factor is different for each mutant.""" RandomWithPerGenerationDither = "generation dither" - """The base vector is random, with a random scaling factor applied to each mutant. + """The base vector is random, with a random scaling factor applied to each mutant. This scaling factor is the same for every mutant, and randomised every generation.""" RandomEitherOrAlgorithm = "either or" - """The base vector is randomly chosen from either a pure random mutation, + """The base vector is randomly chosen from either a pure random mutation, or a pure recombination of parent parameter values.""" @classmethod diff --git a/setup.py b/setup.py index 30763b21..bda3a43a 100644 --- a/setup.py +++ b/setup.py @@ -61,8 +61,8 @@ class BuildExt(build_ext): """A custom build extension for adding compiler-specific options.""" c_opts = { - "msvc": ["/EHsc"], - "unix": ["-fopenmp", "-std=c++11"], + "msvc": ["/O2", "/EHsc"], + "unix": ["-O2", "-fopenmp", "-std=c++11"], } l_opts = { "msvc": [], @@ -71,7 +71,7 @@ class BuildExt(build_ext): if sys.platform == "darwin": darwin_opts = ["-stdlib=libc++", "-mmacosx-version-min=10.9"] - c_opts["unix"] = [*darwin_opts, "-fopenmp"] + c_opts["unix"] = [*darwin_opts, "-fopenmp", "-O2"] l_opts["unix"] = [*darwin_opts, "-lomp"] def build_extensions(self): diff --git a/tests/test_controls.py b/tests/test_controls.py index 3e668795..72f0c745 100644 --- a/tests/test_controls.py +++ b/tests/test_controls.py @@ -60,7 +60,6 @@ def table_str(self): "+---------------------+-----------+\n" "| procedure | calculate |\n" "| parallel | single |\n" - "| calcSldDuringFit | False |\n" "| numSimulationPoints | 500 |\n" "| resampleMinAngle | 0.9 |\n" "| resampleNPoints | 50 |\n" @@ -74,7 +73,6 @@ def table_str(self): "control_property, value", [ ("parallel", Parallel.Single), - ("calcSldDuringFit", False), ("numSimulationPoints", 500), ("resampleMinAngle", 0.9), ("resampleNPoints", 50), @@ -90,7 +88,6 @@ def test_calculate_property_values(self, control_property: str, value: Any) -> N "control_property, value", [ ("parallel", Parallel.Points), - ("calcSldDuringFit", True), ("numSimulationPoints", 10), ("resampleMinAngle", 0.2), ("resampleNPoints", 1), @@ -186,14 +183,6 @@ def test_calculate_parallel_validation(self, value: Any) -> None: with pytest.raises(pydantic.ValidationError, match="Input should be 'single', 'points' or 'contrasts'"): self.calculate.parallel = value - @pytest.mark.parametrize("value", [5.0, 12]) - def test_calculate_calcSldDuringFit_validation(self, value: Union[int, float]) -> None: - """Tests the calcSldDuringFit setter validation in Calculate class.""" - with pytest.raises( - pydantic.ValidationError, match="Input should be a valid boolean, unable to interpret input" - ): - self.calculate.calcSldDuringFit = value - @pytest.mark.parametrize("value", ["test", "iterate", True, 1, 3.0]) def test_calculate_display_validation(self, value: Any) -> None: """Tests the display setter validation in Calculate class.""" @@ -220,7 +209,6 @@ def table_str(self): "+---------------------+---------+\n" "| procedure | simplex |\n" "| parallel | single |\n" - "| calcSldDuringFit | False |\n" "| numSimulationPoints | 500 |\n" "| resampleMinAngle | 0.9 |\n" "| resampleNPoints | 50 |\n" @@ -240,7 +228,6 @@ def table_str(self): "control_property, value", [ ("parallel", Parallel.Single), - ("calcSldDuringFit", False), ("numSimulationPoints", 500), ("resampleMinAngle", 0.9), ("resampleNPoints", 50), @@ -262,7 +249,6 @@ def test_simplex_property_values(self, control_property: str, value: Any) -> Non "control_property, value", [ ("parallel", Parallel.Points), - ("calcSldDuringFit", True), ("numSimulationPoints", 10), ("resampleMinAngle", 0.2), ("resampleNPoints", 1), @@ -380,7 +366,6 @@ def table_str(self): "+----------------------+---------------+\n" "| procedure | de |\n" "| parallel | single |\n" - "| calcSldDuringFit | False |\n" "| numSimulationPoints | 500 |\n" "| resampleMinAngle | 0.9 |\n" "| resampleNPoints | 50 |\n" @@ -402,7 +387,6 @@ def table_str(self): "control_property, value", [ ("parallel", Parallel.Single), - ("calcSldDuringFit", False), ("numSimulationPoints", 500), ("resampleMinAngle", 0.9), ("resampleNPoints", 50), @@ -424,7 +408,6 @@ def test_de_property_values(self, control_property: str, value: Any) -> None: "control_property, value", [ ("parallel", Parallel.Points), - ("calcSldDuringFit", True), ("numSimulationPoints", 10), ("resampleMinAngle", 0.2), ("resampleNPoints", 1), @@ -556,7 +539,6 @@ def table_str(self): "+---------------------+--------+\n" "| procedure | ns |\n" "| parallel | single |\n" - "| calcSldDuringFit | False |\n" "| numSimulationPoints | 500 |\n" "| resampleMinAngle | 0.9 |\n" "| resampleNPoints | 50 |\n" @@ -574,7 +556,6 @@ def table_str(self): "control_property, value", [ ("parallel", Parallel.Single), - ("calcSldDuringFit", False), ("numSimulationPoints", 500), ("resampleMinAngle", 0.9), ("resampleNPoints", 50), @@ -594,7 +575,6 @@ def test_ns_property_values(self, control_property: str, value: Any) -> None: "control_property, value", [ ("parallel", Parallel.Points), - ("calcSldDuringFit", True), ("numSimulationPoints", 10), ("resampleMinAngle", 0.2), ("resampleNPoints", 1), @@ -725,7 +705,6 @@ def table_str(self): "+---------------------+---------+\n" "| procedure | dream |\n" "| parallel | single |\n" - "| calcSldDuringFit | False |\n" "| numSimulationPoints | 500 |\n" "| resampleMinAngle | 0.9 |\n" "| resampleNPoints | 50 |\n" @@ -745,7 +724,6 @@ def table_str(self): "control_property, value", [ ("parallel", Parallel.Single), - ("calcSldDuringFit", False), ("numSimulationPoints", 500), ("resampleMinAngle", 0.9), ("resampleNPoints", 50), @@ -767,7 +745,6 @@ def test_dream_property_values(self, control_property: str, value: Any) -> None: "control_property, value", [ ("parallel", Parallel.Points), - ("calcSldDuringFit", True), ("numSimulationPoints", 10), ("resampleMinAngle", 0.2), ("resampleNPoints", 1), diff --git a/tests/test_inputs.py b/tests/test_inputs.py index 2c207928..db5b6646 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -359,7 +359,6 @@ def standard_layers_controls(): controls = Control() controls.procedure = Procedures.Calculate controls.parallel = Parallel.Single - controls.calcSldDuringFit = False controls.numSimulationPoints = 500 controls.resampleMinAngle = 0.9 controls.resampleNPoints = 50 @@ -398,7 +397,6 @@ def custom_xy_controls(): controls = Control() controls.procedure = Procedures.Calculate controls.parallel = Parallel.Single - controls.calcSldDuringFit = False controls.numSimulationPoints = 500 controls.resampleMinAngle = 0.9 controls.resampleNPoints = 50.0 @@ -757,7 +755,6 @@ def check_controls_equal(actual_controls, expected_controls) -> None: controls_fields = [ "procedure", "parallel", - "calcSldDuringFit", "numSimulationPoints", "resampleMinAngle", "resampleNPoints", From 9ce6e9fd79970f2f0daa6254827a6b10d049d3a2 Mon Sep 17 00:00:00 2001 From: StephenNneji <34302892+StephenNneji@users.noreply.github.com> Date: Mon, 18 Aug 2025 13:44:08 +0100 Subject: [PATCH 5/8] Fixes incorrect shift on bayes plots (#176) --- ratapi/utils/plotting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ratapi/utils/plotting.py b/ratapi/utils/plotting.py index d4cfa599..e61e40ae 100644 --- a/ratapi/utils/plotting.py +++ b/ratapi/utils/plotting.py @@ -167,7 +167,7 @@ def plot_ref_sld_helper( # Plot confidence intervals if required if confidence_intervals is not None: # Calculate the divisor - div = 1 if i == 0 and not q4 else 2 ** (4 * (i + 1)) + div = 1 if i == 0 and not q4 else 10 ** ((i / 100) * shift_value) ref_min, ref_max = confidence_intervals["reflectivity"][i] mult = (1 if not q4 else plot_data["ref"][i][0] ** 4) / div ref_plot.fill_between(plot_data["ref"][i][0], ref_min * mult, ref_max * mult, alpha=0.6, color="grey") From a57541af1fbc7f5f793f7e4b27a929ce1fada6a3 Mon Sep 17 00:00:00 2001 From: StephenNneji <34302892+StephenNneji@users.noreply.github.com> Date: Mon, 1 Sep 2025 16:22:42 +0100 Subject: [PATCH 6/8] Reverts removal of bliiting (#177) --- ratapi/utils/plotting.py | 205 ++++++++++++++++++++++++++++++++++++++- tests/test_orso_utils.py | 3 + 2 files changed, 207 insertions(+), 1 deletion(-) diff --git a/ratapi/utils/plotting.py b/ratapi/utils/plotting.py index e61e40ae..a6f2f557 100644 --- a/ratapi/utils/plotting.py +++ b/ratapi/utils/plotting.py @@ -356,6 +356,205 @@ def plot_ref_sld( plt.show(block=block) +class BlittingSupport: + """Create a SLD plot that uses blitting to get faster draws. + + The blit plot stores the background from an + initial draw then updates the foreground (lines and error bars) if the background is not changed. + + Parameters + ---------- + data : PlotEventData + The plot event data that contains all the information + to generate the ref and sld plots + fig : matplotlib.pyplot.figure, optional + The figure class that has two subplots + linear_x : bool, default: False + Controls whether the x-axis on reflectivity plot uses the linear scale + q4 : bool, default: False + Controls whether Q^4 is plotted on the reflectivity plot + show_error_bar : bool, default: True + Controls whether the error bars are shown + show_grid : bool, default: False + Controls whether the grid is shown + show_legend : bool, default: True + Controls whether the legend is shown + shift_value : float, default: 100 + A value between 1 and 100 that controls the spacing between the reflectivity plots for each of the contrasts + """ + + def __init__( + self, + data, + fig=None, + linear_x: bool = False, + q4: bool = False, + show_error_bar: bool = True, + show_grid: bool = False, + show_legend: bool = True, + shift_value: float = 100, + ): + self.figure = fig + self.linear_x = linear_x + self.q4 = q4 + self.show_error_bar = show_error_bar + self.show_grid = show_grid + self.show_legend = show_legend + self.shift_value = shift_value + self.update_plot(data) + self.event_id = self.figure.canvas.mpl_connect("resize_event", self.resizeEvent) + + def __del__(self): + self.figure.canvas.mpl_disconnect(self.event_id) + + def resizeEvent(self, _event): + """Ensure the background is updated after a resize event.""" + self.__background_changed = True + + def update(self, data): + """Update the foreground, if background has not changed otherwise it updates full plot. + + Parameters + ---------- + data : PlotEventData + The plot event data that contains all the information + to generate the ref and sld plots + """ + if self.__background_changed: + self.update_plot(data) + else: + self.update_foreground(data) + + def __setattr__(self, name, value): + old_value = getattr(self, name, None) + if value == old_value: + return + + super().__setattr__(name, value) + if name in ["figure", "linear_x", "q4", "show_error_bar", "show_grid", "show_legend", "shift_value"]: + self.__background_changed = True + + def set_animated(self, is_animated: bool): + """Set the animated property of foreground plot elements. + + Parameters + ---------- + is_animated : bool + Indicates if the animated property should be set. + """ + for line in self.figure.axes[0].lines: + line.set_animated(is_animated) + for line in self.figure.axes[1].lines: + line.set_animated(is_animated) + for container in self.figure.axes[0].containers: + container[2][0].set_animated(is_animated) + + def adjust_error_bar(self, error_bar_container, x, y, y_error): + """Adjust the error bar data. + + Parameters + ---------- + error_bar_container : Tuple + Tuple containing the artist of the errorbar i.e. (data line, cap lines, bar lines) + x : np.ndarray + The shifted data x axis data + y : np.ndarray + The shifted data y axis data + y_error : np.ndarray + The shifted data y axis error data + """ + line, _, (bars_y,) = error_bar_container + + line.set_data(x, y) + x_base = x + y_base = y + + y_error_top = y_base + y_error + y_error_bottom = y_base - y_error + + new_segments_y = [np.array([[x, yt], [x, yb]]) for x, yt, yb in zip(x_base, y_error_top, y_error_bottom)] + bars_y.set_segments(new_segments_y) + + def update_plot(self, data): + """Update the full plot. + + Parameters + ---------- + data : PlotEventData + The plot event data that contains all the information + to generate the ref and sld plots + """ + if self.figure is not None: + self.figure.clf() + self.figure = ratapi.plotting.plot_ref_sld_helper( + data, + self.figure, + linear_x=self.linear_x, + q4=self.q4, + show_error_bar=self.show_error_bar, + show_grid=self.show_grid, + show_legend=self.show_legend, + animated=True, + ) + self.figure.tight_layout(pad=1) + self.figure.canvas.draw() + self.bg = self.figure.canvas.copy_from_bbox(self.figure.bbox) + for line in self.figure.axes[0].lines: + self.figure.axes[0].draw_artist(line) + for line in self.figure.axes[1].lines: + self.figure.axes[1].draw_artist(line) + for container in self.figure.axes[0].containers: + self.figure.axes[0].draw_artist(container[2][0]) + self.figure.canvas.blit(self.figure.bbox) + self.set_animated(False) + self.__background_changed = False + + def update_foreground(self, data): + """Update the plot foreground only. + + Parameters + ---------- + data : PlotEventData + The plot event data that contains all the information + to generate the ref and sld plots + """ + self.set_animated(True) + self.figure.canvas.restore_region(self.bg) + plot_data = ratapi.plotting._extract_plot_data(data, self.q4, self.show_error_bar, self.shift_value) + + offset = 2 if self.show_error_bar else 1 + for i in range( + 0, + len(self.figure.axes[0].lines), + ): + self.figure.axes[0].lines[i].set_data(plot_data["ref"][i // offset][0], plot_data["ref"][i // offset][1]) + self.figure.axes[0].draw_artist(self.figure.axes[0].lines[i]) + + i = 0 + for j in range(len(plot_data["sld"])): + for sld in plot_data["sld"][j]: + self.figure.axes[1].lines[i].set_data(sld[0], sld[1]) + self.figure.axes[1].draw_artist(self.figure.axes[1].lines[i]) + i += 1 + + if plot_data["sld_resample"]: + for resampled in plot_data["sld_resample"][j]: + self.figure.axes[1].lines[i].set_data(resampled[0], resampled[1]) + self.figure.axes[1].draw_artist(self.figure.axes[1].lines[i]) + i += 1 + + for i, container in enumerate(self.figure.axes[0].containers): + self.adjust_error_bar( + container, plot_data["error"][i][0], plot_data["error"][i][1], plot_data["error"][i][2] + ) + self.figure.axes[0].draw_artist(container[2][0]) + self.figure.axes[0].draw_artist(container[0]) + + self.figure.canvas.blit(self.figure.bbox) + self.figure.canvas.flush_events() + self.set_animated(False) + + class LivePlot: """Create a plot that gets updates from the plot event during a calculation. @@ -369,6 +568,7 @@ class LivePlot: def __init__(self, block=False): self.block = block self.closed = False + self.blit_plot = None def __enter__(self): self.figure = plt.subplots(1, 2)[0] @@ -394,7 +594,10 @@ def plotEvent(self, event): """ if not self.closed and self.figure.number in plt.get_fignums(): - plot_ref_sld_helper(event, self.figure) + if self.blit_plot is None: + self.blit_plot = BlittingSupport(event, self.figure) + else: + self.blit_plot.update(event) def __exit__(self, _exc_type, _exc_val, _traceback): ratapi.events.clear(ratapi.events.EventTypes.Plot, self.plotEvent) diff --git a/tests/test_orso_utils.py b/tests/test_orso_utils.py index 39137fe9..a8c0791a 100644 --- a/tests/test_orso_utils.py +++ b/tests/test_orso_utils.py @@ -36,6 +36,7 @@ def prist(): ], ) @pytest.mark.parametrize("absorption", [True, False]) +@pytest.mark.skip(reason="orsopy database website (https://slddb.esss.dk/slddb/) is not available") def test_orso_model_to_rat(model, absorption): """Test that orso_model_to_rat gives the expected parameters, layers and model.""" @@ -72,6 +73,7 @@ def test_orso_model_to_rat(model, absorption): "prist5_10K_m_025.Rqz.ort", ], ) +@pytest.mark.skip(reason="orsopy database website (https://slddb.esss.dk/slddb/) is not available") def test_load_ort_data(test_data): """Test that .ort data is loaded correctly.""" # manually get the test data for comparison @@ -104,6 +106,7 @@ def test_load_ort_data(test_data): ["prist5_10K_m_025.Rqz.ort", "prist.json"], ], ) +@pytest.mark.skip(reason="orsopy database website (https://slddb.esss.dk/slddb/) is not available") def test_load_ort_project(test_data, expected_data): """Test that a project with model data is loaded correctly.""" ort_data = ORSOProject(Path(TEST_DIR_PATH, test_data)) From cdaa93d10bd36628ff0f47e5721e6e860d1baada Mon Sep 17 00:00:00 2001 From: Paul Sharp <44529197+DrPaulSharp@users.noreply.github.com> Date: Mon, 1 Sep 2025 17:13:34 +0100 Subject: [PATCH 7/8] Adds "repeat_layers" option to contrasts (#178) * Adds validator to warn for resample=False for custom XY * Adds "repeat_layers" option to contrasts * Reverts changes to test data --- cpp/RAT | 2 +- ratapi/inputs.py | 2 +- ratapi/models.py | 9 +++++++++ ratapi/project.py | 31 +++++++++++++++++++++++++++++-- tests/test_inputs.py | 2 +- tests/test_project.py | 40 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 81 insertions(+), 5 deletions(-) diff --git a/cpp/RAT b/cpp/RAT index 16f3ebef..aae3dc14 160000 --- a/cpp/RAT +++ b/cpp/RAT @@ -1 +1 @@ -Subproject commit 16f3ebef737f4c85ca601046f7a30ffabee4eb47 +Subproject commit aae3dc141b6a10c6e10dfb47cd62e07a2a11857d diff --git a/ratapi/inputs.py b/ratapi/inputs.py index 890066f9..7d0a8720 100644 --- a/ratapi/inputs.py +++ b/ratapi/inputs.py @@ -302,7 +302,7 @@ def make_problem(project: ratapi.Project) -> ProblemDefinition: problem.numberOfContrasts = len(project.contrasts) problem.geometry = project.geometry problem.useImaginary = project.absorption - problem.repeatLayers = [1] * len(project.contrasts) + problem.repeatLayers = [contrast.repeat_layers for contrast in project.contrasts] problem.contrastBackgroundParams = contrast_background_params problem.contrastBackgroundTypes = contrast_background_types problem.contrastBackgroundActions = [contrast.background_action for contrast in project.contrasts] diff --git a/ratapi/models.py b/ratapi/models.py index e882af65..7d0d4f74 100644 --- a/ratapi/models.py +++ b/ratapi/models.py @@ -163,6 +163,8 @@ class Contrast(RATModel): The name of the instrument resolution for this contrast. resample : bool Whether adaptive resampling should be used for interface microslicing. + repeat_layers : int + For standard layers, the number of times the set of layers defined in the model should be repeated. model : list[str] If this is a standard layers model, this should be a list of layer names that make up the slab model for this contrast. @@ -180,6 +182,7 @@ class Contrast(RATModel): scalefactor: str = "" resolution: str = "" resample: bool = False + repeat_layers: int = Field(default=1, gt=0) model: list[str] = [] @model_validator(mode="before") @@ -208,6 +211,7 @@ def __str__(self): self.scalefactor, self.resolution, self.resample, + self.repeat_layers, model_entry, ] ) @@ -238,6 +242,8 @@ class ContrastWithRatio(RATModel): The name of the instrument resolution for this contrast. resample : bool Whether adaptive resampling should be used for interface microslicing. + repeat_layers : int + For standard layers, the number of times the set of layers defined in the model should be repeated. domain_ratio : str The name of the domain ratio parameter describing how the first domain should be weighted relative to the second. @@ -258,6 +264,7 @@ class ContrastWithRatio(RATModel): scalefactor: str = "" resolution: str = "" resample: bool = False + repeat_layers: int = Field(default=1, gt=0) domain_ratio: str = "" model: list[str] = [] @@ -276,6 +283,8 @@ def __str__(self): self.scalefactor, self.resolution, self.resample, + self.repeat_layers, + self.domain_ratio, model_entry, ] ) diff --git a/ratapi/project.py b/ratapi/project.py index 860f23b7..88555539 100644 --- a/ratapi/project.py +++ b/ratapi/project.py @@ -361,8 +361,8 @@ def model_post_init(self, __context: Any) -> None: """Set up the Class to protect against disallowed modification. We initialise the class handle in the ClassLists for empty data fields, set protected parameters, get names of - all defined parameters, determine the contents of the "model" field in contrasts, - and wrap ClassList routines to control revalidation. + all defined parameters, determine the contents of the "model" field in contrasts, and wrap ClassList routines + to control revalidation. """ # Ensure all ClassLists have the correct _class_handle defined for field in (fields := Project.model_fields): @@ -454,6 +454,33 @@ def set_layers(self) -> "Project": self.layers.data = [] return self + @model_validator(mode="after") + def set_repeat_layers(self) -> "Project": + """If we are not using a standard layers model, warn that the repeat layers setting is not valid.""" + if self.model != LayerModels.StandardLayers: + for contrast in self.contrasts: + if "repeat_layers" in contrast.model_fields_set and contrast.repeat_layers != 1: + warnings.warn( + 'For a custom layers or custom XY calculation, the "repeat_layers" setting for each ' + "contrast is not valid - resetting to 1.", + stacklevel=2, + ) + contrast.repeat_layers = 1 + return self + + @model_validator(mode="after") + def set_resample(self) -> "Project": + """If we are using a custom XY model, warn that the resample setting for each contrast must always be True.""" + if self.model == LayerModels.CustomXY: + for contrast in self.contrasts: + if "resample" in contrast.model_fields_set and contrast.resample is False: + warnings.warn( + 'For a custom XY calculation, "resample" must be True for each contrast - resetting to True.', + stacklevel=2, + ) + contrast.resample = True + return self + @model_validator(mode="after") def set_calculation(self) -> "Project": """Apply the calc setting to the project.""" diff --git a/tests/test_inputs.py b/tests/test_inputs.py index db5b6646..e7a1560e 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -306,7 +306,7 @@ def custom_xy_problem(test_names, test_checks): problem.contrastResolutionTypes = ["constant"] problem.contrastCustomFiles = [1] problem.contrastDomainRatios = [0] - problem.resample = [False] + problem.resample = [True] problem.dataPresent = [0] problem.data = [np.empty([0, 6])] problem.dataLimits = [[0.0, 0.0]] diff --git a/tests/test_project.py b/tests/test_project.py index c2ea92a2..7683dec9 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -443,6 +443,46 @@ def test_set_layers(project_parameters: dict) -> None: assert project.layers == [] +@pytest.mark.parametrize( + "project_parameters", + [ + ( + { + "model": LayerModels.CustomLayers, + "contrasts": [ratapi.models.Contrast(name="Test Contrast", repeat_layers=2)], + } + ), + ({"model": LayerModels.CustomXY, "contrasts": [ratapi.models.Contrast(name="Test Contrast", repeat_layers=2)]}), + ], +) +def test_set_repeat_layers(project_parameters: dict) -> None: + """If we are using a custom layers of custom XY model, the "resample" field of all the contrasts should always + be 1.""" + with pytest.warns( + match='For a custom layers or custom XY calculation, the "repeat_layers" setting for each ' + "contrast is not valid - resetting to 1." + ): + project = ratapi.Project(**project_parameters) + assert all(contrast.repeat_layers == 1 for contrast in project.contrasts) + + +@pytest.mark.parametrize( + "project_parameters", + [ + ({"model": LayerModels.CustomXY, "contrasts": [ratapi.models.Contrast(name="Test Contrast")]}), + ], +) +def test_set_resample(project_parameters: dict) -> None: + """If we are using a custom XY model, the "resample" field of all the contrasts should always be True.""" + project = ratapi.Project(**project_parameters) + assert all(contrast.resample for contrast in project.contrasts) + with pytest.warns( + match='For a custom XY calculation, "resample" must be True for each contrast - resetting to True.' + ): + project.contrasts.append(name="New Contrast", resample=False) + assert all(contrast.resample for contrast in project.contrasts) + + @pytest.mark.parametrize( ["input_calculation", "input_contrast", "new_calculation", "new_contrast_model", "num_domain_ratios"], [ From f978f4c7ae3f0ec13ace6b0b7d2cf0cfe0550184 Mon Sep 17 00:00:00 2001 From: StephenNneji <34302892+StephenNneji@users.noreply.github.com> Date: Tue, 2 Sep 2025 09:00:57 +0100 Subject: [PATCH 8/8] Updates version to 0.0.0.dev8 (#179) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index bda3a43a..b5871644 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ from setuptools.command.build_clib import build_clib from setuptools.command.build_ext import build_ext -__version__ = "0.0.0.dev7" +__version__ = "0.0.0.dev8" PACKAGE_NAME = "ratapi" with open("README.md") as f: