From 929b7149494414ac8a0249557880ed85afd138c0 Mon Sep 17 00:00:00 2001 From: StephenNneji <34302892+StephenNneji@users.noreply.github.com> Date: Thu, 18 Sep 2025 10:03:49 +0100 Subject: [PATCH 1/4] Unify save methods, refactor plots, enable OpenMP windows(#180) --- .github/workflows/build_wheel.yml | 14 +-- .github/workflows/run_tests.yml | 14 +-- pyproject.toml | 3 +- ratapi/controls.py | 13 +-- ratapi/examples/domains/domains_XY_model.py | 20 ++-- .../normal_reflectivity/custom_XY_DSPC.py | 26 +++-- ratapi/utils/plotting.py | 96 +++++++++++++------ setup.py | 2 +- tests/test_controls.py | 25 +++++ tests/test_plotting.py | 17 ++-- 10 files changed, 143 insertions(+), 87 deletions(-) diff --git a/.github/workflows/build_wheel.yml b/.github/workflows/build_wheel.yml index 7f0205b5..3fa3f319 100644 --- a/.github/workflows/build_wheel.yml +++ b/.github/workflows/build_wheel.yml @@ -11,7 +11,7 @@ jobs: strategy: fail-fast: true matrix: - platform: [windows-latest, ubuntu-latest, macos-13, macos-14] + platform: [windows-2022, ubuntu-latest, macos-13, macos-14] env: CIBW_SKIP: 'pp*' CIBW_ARCHS: 'auto64' @@ -34,9 +34,9 @@ jobs: - name: Install OMP (MacOS Intel) if: matrix.platform == 'macos-13' run: | - brew install llvm libomp - echo "export CC=/usr/local/opt/llvm/bin/clang" >> ~/.bashrc - echo "export CXX=/usr/local/opt/llvm/bin/clang++" >> ~/.bashrc + brew install llvm@20 libomp + echo "export CC=/usr/local/opt/llvm@20/bin/clang" >> ~/.bashrc + echo "export CXX=/usr/local/opt/llvm@20/bin/clang++" >> ~/.bashrc echo "export CFLAGS=\"$CFLAGS -I/usr/local/opt/libomp/include\"" >> ~/.bashrc echo "export CXXFLAGS=\"$CXXFLAGS -I/usr/local/opt/libomp/include\"" >> ~/.bashrc echo "export LDFLAGS=\"$LDFLAGS -Wl,-rpath,/usr/local/opt/libomp/lib -L/usr/local/opt/libomp/lib -lomp\"" >> ~/.bashrc @@ -44,9 +44,9 @@ jobs: - name: Install OMP (MacOS M1) if: matrix.platform == 'macos-14' run: | - brew install llvm libomp - echo "export CC=/opt/homebrew/opt/llvm/bin/clang" >> ~/.bashrc - echo "export CXX=/opt/homebrew/opt/llvm/bin/clang++" >> ~/.bashrc + brew install llvm@20 libomp + echo "export CC=/opt/homebrew/opt/llvm@20/bin/clang" >> ~/.bashrc + echo "export CXX=/opt/homebrew/opt/llvm@20/bin/clang++" >> ~/.bashrc echo "export CFLAGS=\"$CFLAGS -I/opt/homebrew/opt/libomp/include\"" >> ~/.bashrc echo "export CXXFLAGS=\"$CXXFLAGS -I/opt/homebrew/opt/libomp/include\"" >> ~/.bashrc echo "export LDFLAGS=\"$LDFLAGS -Wl,-rpath,/opt/homebrew/opt/libomp/lib -L/opt/homebrew/opt/libomp/lib -lomp\"" >> ~/.bashrc diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index a2de2b66..6936d3ca 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -19,7 +19,7 @@ jobs: strategy: fail-fast: false matrix: - platform: [windows-latest, ubuntu-latest, macos-13, macos-14] + platform: [windows-2022, ubuntu-latest, macos-13, macos-14] version: ["3.10", "3.13"] defaults: run: @@ -38,9 +38,9 @@ jobs: - name: Install OMP (MacOS Intel) if: matrix.platform == 'macos-13' run: | - brew install llvm libomp - echo "export CC=/usr/local/opt/llvm/bin/clang" >> ~/.bashrc - echo "export CXX=/usr/local/opt/llvm/bin/clang++" >> ~/.bashrc + brew install llvm@20 libomp + echo "export CC=/usr/local/opt/llvm@20/bin/clang" >> ~/.bashrc + echo "export CXX=/usr/local/opt/llvm@20/bin/clang++" >> ~/.bashrc echo "export CFLAGS=\"$CFLAGS -I/usr/local/opt/libomp/include\"" >> ~/.bashrc echo "export CXXFLAGS=\"$CXXFLAGS -I/usr/local/opt/libomp/include\"" >> ~/.bashrc echo "export LDFLAGS=\"$LDFLAGS -Wl,-rpath,/usr/local/opt/libomp/lib -L/usr/local/opt/libomp/lib -lomp\"" >> ~/.bashrc @@ -48,9 +48,9 @@ jobs: - name: Install OMP (MacOS M1) if: matrix.platform == 'macos-14' run: | - brew install llvm libomp - echo "export CC=/opt/homebrew/opt/llvm/bin/clang" >> ~/.bashrc - echo "export CXX=/opt/homebrew/opt/llvm/bin/clang++" >> ~/.bashrc + brew install llvm@20 libomp + echo "export CC=/opt/homebrew/opt/llvm@20/bin/clang" >> ~/.bashrc + echo "export CXX=/opt/homebrew/opt/llvm@20/bin/clang++" >> ~/.bashrc echo "export CFLAGS=\"$CFLAGS -I/opt/homebrew/opt/libomp/include\"" >> ~/.bashrc echo "export CXXFLAGS=\"$CXXFLAGS -I/opt/homebrew/opt/libomp/include\"" >> ~/.bashrc echo "export LDFLAGS=\"$LDFLAGS -Wl,-rpath,/opt/homebrew/opt/libomp/lib -L/opt/homebrew/opt/libomp/lib -lomp\"" >> ~/.bashrc diff --git a/pyproject.toml b/pyproject.toml index 35febb57..d52f72a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,4 +38,5 @@ mark-parentheses = false [tool.ruff.lint.pydocstyle] convention = "numpy" - +[tool.ruff.lint.isort] +known-first-party = ["ratapi.rat_core"] diff --git a/ratapi/controls.py b/ratapi/controls.py index c5408411..ea62f5d2 100644 --- a/ratapi/controls.py +++ b/ratapi/controls.py @@ -233,19 +233,16 @@ def delete_IPC(self): os.remove(self._IPCFilePath) return None - def save(self, path: Union[str, Path], filename: str = "controls"): + def save(self, filepath: Union[str, Path] = "./controls.json"): """Save a controls object to a JSON file. Parameters ---------- - path : str or Path - The directory in which the controls object will be written. - filename : str - The name for the JSON file containing the controls object. - + filepath : str or Path + The path to where the controls file will be written. """ - file = Path(path, f"{filename.removesuffix('.json')}.json") - file.write_text(self.model_dump_json()) + filepath = Path(filepath).with_suffix(".json") + filepath.write_text(self.model_dump_json()) @classmethod def load(cls, path: Union[str, Path]) -> "Controls": diff --git a/ratapi/examples/domains/domains_XY_model.py b/ratapi/examples/domains/domains_XY_model.py index 8aeb8c77..00567666 100644 --- a/ratapi/examples/domains/domains_XY_model.py +++ b/ratapi/examples/domains/domains_XY_model.py @@ -1,8 +1,9 @@ """Custom model file for the domains custom XY example.""" -import math +from math import sqrt import numpy as np +from scipy.special import erf def domains_XY_model(params, bulk_in, bulk_out, contrast, domain): @@ -19,13 +20,13 @@ def domains_XY_model(params, bulk_in, bulk_out, contrast, domain): z = np.arange(0, 141) # Make the volume fraction distribution for our Silicon substrate - [vfSilicon, siSurf] = makeLayer(z, -25, 50, 1, subRough, subRough) + [vfSilicon, siSurf] = make_layer(z, -25, 50, 1, subRough, subRough) # ... and the Oxide ... - [vfOxide, oxSurface] = makeLayer(z, siSurf, oxideThick, 1, subRough, subRough) + [vfOxide, oxSurface] = make_layer(z, siSurf, oxideThick, 1, subRough, subRough) # ... and also our layer. - [vfLayer, laySurface] = makeLayer(z, oxSurface, layerThick, 1, subRough, layerRough) + [vfLayer, laySurface] = make_layer(z, oxSurface, layerThick, 1, subRough, layerRough) # Everything that is not already occupied will be filled will water totalVF = vfSilicon + vfOxide + vfLayer @@ -53,7 +54,7 @@ def domains_XY_model(params, bulk_in, bulk_out, contrast, domain): return SLD, subRough -def makeLayer(z, prevLaySurf, thickness, height, Sigma_L, Sigma_R): +def make_layer(z, prevLaySurf, thickness, height, Sigma_L, Sigma_R): """Produce a layer, with a defined thickness, height and roughness. Each side of the layer has its own roughness value. @@ -63,12 +64,9 @@ def makeLayer(z, prevLaySurf, thickness, height, Sigma_L, Sigma_R): right = prevLaySurf + thickness # Make our heaviside - a = (z - left) / ((2**0.5) * Sigma_L) - b = (z - right) / ((2**0.5) * Sigma_R) + erf_left = erf((z - left) / (sqrt(2) * Sigma_L)) + erf_right = erf((z - right) / (sqrt(2) * Sigma_R)) - erf_a = np.array([math.erf(value) for value in a]) - erf_b = np.array([math.erf(value) for value in b]) - - VF = np.array((height / 2) * (erf_a - erf_b)) + VF = np.array((0.5 * height) * (erf_left - erf_right)) return VF, right diff --git a/ratapi/examples/normal_reflectivity/custom_XY_DSPC.py b/ratapi/examples/normal_reflectivity/custom_XY_DSPC.py index dc1d1013..93e25b08 100644 --- a/ratapi/examples/normal_reflectivity/custom_XY_DSPC.py +++ b/ratapi/examples/normal_reflectivity/custom_XY_DSPC.py @@ -1,8 +1,9 @@ """A custom XY model for a supported DSPC bilayer.""" -import math +from math import sqrt import numpy as np +from scipy.special import erf def custom_XY_DSPC(params, bulk_in, bulk_out, contrast): @@ -51,10 +52,10 @@ def custom_XY_DSPC(params, bulk_in, bulk_out, contrast): z = np.arange(0, 141) # Make our Silicon substrate - vfSilicon, siSurf = layer(z, -25, 50, 1, subRough, subRough) + vfSilicon, siSurf = make_layer(z, -25, 50, 1, subRough, subRough) # Add the Oxide - vfOxide, oxSurface = layer(z, siSurf, oxideThick, 1, subRough, subRough) + vfOxide, oxSurface = make_layer(z, siSurf, oxideThick, 1, subRough, subRough) # We fill in the water at the end, but there may be a hydration layer between the bilayer and the oxide, # so we start the bilayer stack an appropriate distance away @@ -65,15 +66,15 @@ def custom_XY_DSPC(params, bulk_in, bulk_out, contrast): headThick = vHead / lipidAPM # ... and make a box for the volume fraction (1 for now, we correct for coverage later) - vfHeadL, headLSurface = layer(z, watSurface, headThick, 1, bilayerRough, bilayerRough) + vfHeadL, headLSurface = make_layer(z, watSurface, headThick, 1, bilayerRough, bilayerRough) # ... also do the same for the tails # We'll make both together, so the thickness will be twice the volume tailsThick = (2 * vTail) / lipidAPM - vfTails, tailsSurf = layer(z, headLSurface, tailsThick, 1, bilayerRough, bilayerRough) + vfTails, tailsSurf = make_layer(z, headLSurface, tailsThick, 1, bilayerRough, bilayerRough) # Finally the upper head ... - vfHeadR, headSurface = layer(z, tailsSurf, headThick, 1, bilayerRough, bilayerRough) + vfHeadR, headSurface = make_layer(z, tailsSurf, headThick, 1, bilayerRough, bilayerRough) # Making the model # We've created the volume fraction profiles corresponding to each of the groups. @@ -114,12 +115,12 @@ def custom_XY_DSPC(params, bulk_in, bulk_out, contrast): totSLD = sldSilicon + sldOxide + sldHeadL + sldTails + sldHeadR + sldWat # Make the SLD array for output - SLD = [[a, b] for (a, b) in zip(z, totSLD)] + SLD = np.column_stack((z, totSLD)) return SLD, subRough -def layer(z, prevLaySurf, thickness, height, Sigma_L, Sigma_R): +def make_layer(z, prevLaySurf, thickness, height, Sigma_L, Sigma_R): """Produce a layer, with a defined thickness, height and roughness. Each side of the layer has its own roughness value. @@ -129,12 +130,9 @@ def layer(z, prevLaySurf, thickness, height, Sigma_L, Sigma_R): right = prevLaySurf + thickness # Make our heaviside - a = (z - left) / ((2**0.5) * Sigma_L) - b = (z - right) / ((2**0.5) * Sigma_R) + erf_left = erf((z - left) / (sqrt(2) * Sigma_L)) + erf_right = erf((z - right) / (sqrt(2) * Sigma_R)) - erf_a = np.array([math.erf(value) for value in a]) - erf_b = np.array([math.erf(value) for value in b]) - - VF = np.array((height / 2) * (erf_a - erf_b)) + VF = np.array((0.5 * height) * (erf_left - erf_right)) return VF, right diff --git a/ratapi/utils/plotting.py b/ratapi/utils/plotting.py index a6f2f557..c2823b8f 100644 --- a/ratapi/utils/plotting.py +++ b/ratapi/utils/plotting.py @@ -94,7 +94,7 @@ def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool def plot_ref_sld_helper( data: PlotEventData, - fig: Optional[matplotlib.pyplot.figure] = None, + fig: matplotlib.pyplot.figure, delay: bool = True, confidence_intervals: Union[dict, None] = None, linear_x: bool = False, @@ -112,8 +112,8 @@ def plot_ref_sld_helper( data : PlotEventData The plot event data that contains all the information to generate the ref and sld plots - fig : matplotlib.pyplot.figure, optional - The figure class that has two subplots + fig : matplotlib.pyplot.figure + The figure object that has two subplots delay : bool, default: True Controls whether to delay 0.005s after plot is created confidence_intervals : dict or None, default None @@ -134,19 +134,13 @@ def plot_ref_sld_helper( animated : bool, default: False Controls whether the animated property of foreground plot elements should be set. - Returns - ------- - fig : matplotlib.pyplot.figure - The figure class that has two subplots - """ preserve_zoom = False - if fig is None: - fig = plt.subplots(1, 2)[0] - elif len(fig.axes) != 2: + if len(fig.axes) != 2: fig.clf() fig.subplots(1, 2) + fig.subplots_adjust(wspace=0.3) ref_plot: plt.Axes = fig.axes[0] @@ -233,13 +227,12 @@ def plot_ref_sld_helper( if delay: plt.pause(0.005) - return fig - def plot_ref_sld( project: ratapi.Project, results: Union[ratapi.outputs.Results, ratapi.outputs.BayesResults], block: bool = False, + fig: Optional[matplotlib.pyplot.figure] = None, return_fig: bool = False, bayes: Literal[65, 95, None] = None, linear_x: bool = False, @@ -259,6 +252,8 @@ def plot_ref_sld( The result from the calculation block : bool, default: False Indicates the plot should block until it is closed + fig : matplotlib.pyplot.figure, optional + The figure object that has two subplots return_fig : bool, default False If True, return the figure instead of displaying it. bayes : 65, 95 or None, default None @@ -336,11 +331,15 @@ def plot_ref_sld( else: confidence_intervals = None - figure = plt.subplots(1, 2)[0] + if fig is None: + fig = plt.subplots(1, 2)[0] + elif len(fig.axes) != 2: + fig.clf() + fig.subplots(1, 2) plot_ref_sld_helper( data, - figure, + fig, confidence_intervals=confidence_intervals, linear_x=linear_x, q4=q4, @@ -351,7 +350,7 @@ def plot_ref_sld( ) if return_fig: - return figure + return fig plt.show(block=block) @@ -486,7 +485,7 @@ def update_plot(self, data): """ if self.figure is not None: self.figure.clf() - self.figure = ratapi.plotting.plot_ref_sld_helper( + plot_ref_sld_helper( data, self.figure, linear_x=self.linear_x, @@ -520,7 +519,7 @@ def update_foreground(self, data): """ self.set_animated(True) self.figure.canvas.restore_region(self.bg) - plot_data = ratapi.plotting._extract_plot_data(data, self.q4, self.show_error_bar, self.shift_value) + plot_data = _extract_plot_data(data, self.q4, self.show_error_bar, self.shift_value) offset = 2 if self.show_error_bar else 1 for i in range( @@ -649,9 +648,11 @@ def plot_corner( params: Union[list[Union[int, str]], None] = None, smooth: bool = True, block: bool = False, + fig: Optional[matplotlib.pyplot.figure] = None, return_fig: bool = False, hist_kwargs: Union[dict, None] = None, hist2d_kwargs: Union[dict, None] = None, + progress_callback: Union[Callable[[int, int], None], None] = None, ): """Create a corner plot from a Bayesian analysis. @@ -666,6 +667,8 @@ def plot_corner( Whether to apply Gaussian smoothing to the corner plot. block : bool, default False Whether Python should block until the plot is closed. + fig : matplotlib.pyplot.figure, optional + The figure object to use for plot. return_fig: bool, default False If True, return the figure as an object instead of showing it. hist_kwargs : dict @@ -674,6 +677,9 @@ def plot_corner( hist2d_kwargs : dict Extra keyword arguments to pass to the 2d histograms. Default is {'density': True, 'bins': 25} + progress_callback: Union[Callable[[int, int], None], None] + Callback function for providing progress during plot creation + First argument is current completed sub plot and second is total number of sub plots Returns ------- @@ -695,24 +701,32 @@ def plot_corner( hist2d_kwargs = {} num_params = len(params) + total_count = num_params + (num_params**2 - num_params) // 2 + + if fig is None: + fig, axes = plt.subplots(num_params, num_params, figsize=(11, 10), subplot_kw={"visible": False}) + else: + fig.clf() + axes = fig.subplots(num_params, num_params, subplot_kw={"visible": False}) - fig, axes = plt.subplots(num_params, num_params, figsize=(11, 10)) # i is row, j is column - for i, row_param in enumerate(params): - for j, col_param in enumerate(params): - current_axes: Axes = axes[i][j] + current_count = 0 + for i in range(num_params): + for j in range(i + 1): + row_param = params[i] + col_param = params[j] + current_axes: Axes = axes if isinstance(axes, matplotlib.axes.Axes) else axes[i][j] current_axes.tick_params(which="both", labelsize="medium") current_axes.xaxis.offsetText.set_fontsize("small") current_axes.yaxis.offsetText.set_fontsize("small") - + current_axes.set_visible(True) if i == j: # diagonal: histograms plot_one_hist(results, param=row_param, smooth=smooth, axes=current_axes, **hist_kwargs) elif i > j: # lower triangle: 2d histograms plot_contour( results, x_param=col_param, y_param=row_param, smooth=smooth, axes=current_axes, **hist2d_kwargs ) - elif i < j: # upper triangle: no plot - current_axes.set_visible(False) + # remove label if on inside of corner plot if j != 0: current_axes.get_yaxis().set_visible(False) @@ -725,6 +739,9 @@ def plot_corner( current_axes.yaxis.offset_text_position = "center" current_axes.set_ylabel("") current_axes.set_xlabel("") + if progress_callback is not None: + current_count += 1 + progress_callback(current_count, total_count) if return_fig: return fig plt.show(block=block) @@ -956,7 +973,9 @@ def plot_contour( plt.show(block=block) -def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.figure.Figure: +def panel_plot_helper( + plot_func: Callable, indices: list[int], fig: Optional[matplotlib.pyplot.figure] = None +) -> matplotlib.figure.Figure: """Generate a panel-based plot from a single plot function. Parameters @@ -965,6 +984,8 @@ def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.fig A function which plots one parameter on an Axes object, given its index. indices : list[int] The list of indices to pass into ``plot_func``. + fig : matplotlib.pyplot.figure, optional + The figure object to use for plot. Returns ------- @@ -974,10 +995,18 @@ def panel_plot_helper(plot_func: Callable, indices: list[int]) -> matplotlib.fig """ nplots = len(indices) nrows, ncols = ceil(sqrt(nplots)), round(sqrt(nplots)) - fig = plt.subplots(nrows, ncols, figsize=(11, 10))[0] + + if fig is None: + fig = plt.subplots(nrows, ncols, figsize=(11, 10))[0] + else: + fig.clf() + fig.subplots(nrows, ncols) axs = fig.get_axes() for plot_num, index in enumerate(indices): + axs[plot_num].tick_params(which="both", labelsize="medium") + axs[plot_num].xaxis.offsetText.set_fontsize("small") + axs[plot_num].yaxis.offsetText.set_fontsize("small") plot_func(axs[plot_num], index) # blank unused plots @@ -998,6 +1027,7 @@ def plot_hists( dict[Literal["normal", "lognor", "kernel", None]], Literal["normal", "lognor", "kernel", None] ] = None, block: bool = False, + fig: Optional[matplotlib.pyplot.figure] = None, return_fig: bool = False, **hist_settings, ): @@ -1031,6 +1061,8 @@ def plot_hists( e.g. to apply 'normal' to all unset parameters, set `estimated_density = {'default': 'normal'}`. block : bool, default False Whether Python should block until the plot is closed. + fig : matplotlib.pyplot.figure, optional + The figure object to use for plot. return_fig: bool, default False If True, return the figure as an object instead of showing it. hist_settings : @@ -1090,6 +1122,7 @@ def validate_dens_type(dens_type: Union[str, None], param: str): **hist_settings, ), params, + fig, ) if return_fig: return fig @@ -1102,6 +1135,7 @@ def plot_chain( params: Union[list[Union[int, str]], None] = None, maxpoints: int = 15000, block: bool = False, + fig: Optional[matplotlib.pyplot.figure] = None, return_fig: bool = False, ): """Plot the MCMC chain for each parameter of a Bayesian analysis. @@ -1117,6 +1151,8 @@ def plot_chain( The maximum number of points to plot for each parameter. block : bool, default False Whether Python should block until the plot is closed. + fig : matplotlib.pyplot.figure, optional + The figure object to use for plot. return_fig: bool, default False If True, return the figure as an object instead of showing it. @@ -1127,7 +1163,7 @@ def plot_chain( """ chain = results.chain - nsimulations, nplots = chain.shape + nsimulations, _ = chain.shape # skip is to evenly distribute points plotted # all points will be plotted if maxpoints < nsimulations skip = max(floor(nsimulations / maxpoints), 1) @@ -1142,9 +1178,9 @@ def plot_chain( def plot_one_chain(axes: Axes, i: int): axes.plot(range(0, nsimulations, skip), chain[:, i][0:nsimulations:skip]) - axes.set_title(results.fitNames[i]) + axes.set_title(results.fitNames[i], fontsize="small") - fig = panel_plot_helper(plot_one_chain, params) + fig = panel_plot_helper(plot_one_chain, params, fig=fig) if return_fig: return fig plt.show(block=block) diff --git a/setup.py b/setup.py index b5871644..4c996362 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,7 @@ class BuildExt(build_ext): """A custom build extension for adding compiler-specific options.""" c_opts = { - "msvc": ["/O2", "/EHsc"], + "msvc": ["/O2", "/EHsc", "/openmp"], "unix": ["-O2", "-fopenmp", "-std=c++11"], } l_opts = { diff --git a/tests/test_controls.py b/tests/test_controls.py index 72f0c745..61d68331 100644 --- a/tests/test_controls.py +++ b/tests/test_controls.py @@ -3,6 +3,7 @@ import contextlib import os import tempfile +from pathlib import Path from typing import Any, Union import pydantic @@ -45,6 +46,30 @@ def test_extra_property_error() -> None: controls.test = 1 +@pytest.mark.parametrize( + "inputs", + [ + {"parallel": Parallel.Contrasts, "resampleMinAngle": 0.66}, + {"procedure": "simplex"}, + {"procedure": "dream", "nSamples": 504, "nChains": 1200}, + {"procedure": "de", "crossoverProbability": 0.45, "strategy": Strategies.RandomEitherOrAlgorithm}, + {"procedure": "ns", "nMCMC": 4, "propScale": 0.6}, + ], +) +def test_save_load(inputs): + """Test that saving and loading controls returns the same controls.""" + + original_controls = Controls(**inputs) + with tempfile.TemporaryDirectory() as tmp: + # ignore relative path warnings + path = Path(tmp, "controls.json") + original_controls.save(path) + converted_controls = Controls.load(path) + + for field in Controls.model_fields: + assert getattr(converted_controls, field) == getattr(original_controls, field) + + class TestCalculate: """Tests the Calculate class.""" diff --git a/tests/test_plotting.py b/tests/test_plotting.py index c42bfeea..222d5142 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -50,7 +50,8 @@ def fig(request) -> plt.figure: """Creates the fixture for the tests.""" plt.close("all") figure = plt.subplots(1, 2)[0] - return RATplot.plot_ref_sld_helper(fig=figure, data=domains_data() if request.param else data()) + RATplot.plot_ref_sld_helper(fig=figure, data=domains_data() if request.param else data()) + return figure @pytest.fixture @@ -68,7 +69,8 @@ def bayes_fig(request) -> plt.figure: for sld in dat.sldProfiles ], } - return RATplot.plot_ref_sld_helper(data=dat, fig=figure, confidence_intervals=confidence_intervals) + RATplot.plot_ref_sld_helper(data=dat, fig=figure, confidence_intervals=confidence_intervals) + return figure @pytest.mark.parametrize("fig", [False, True], indirect=True) @@ -120,8 +122,7 @@ def test_ref_sld_color_formatting(fig: plt.figure) -> None: assert sld_plot.get_lines()[i].get_color() == sld_plot.get_lines()[i + 1].get_color() -@pytest.mark.parametrize("bayes", [65, 95]) -def test_ref_sld_bayes(fig, bayes_fig, bayes): +def test_ref_sld_bayes(fig, bayes_fig): """Test that shading is correctly added to the figure when confidence intervals are supplied.""" # the shading is of type PolyCollection for axes in fig.axes: @@ -137,7 +138,7 @@ def test_sld_profile_function_call(mock: MagicMock) -> None: """Tests the makeSLDProfile function called with correct args. """ - RATplot.plot_ref_sld_helper(data()) + RATplot.plot_ref_sld_helper(data(), plt.subplots(1, 2)[0]) assert mock.call_count == 3 assert mock.call_args_list[0].args[0] == 2.07e-06 @@ -211,9 +212,9 @@ def test_plot_ref_sld(mock: MagicMock, input_project, reflectivity_calculation_r def test_ref_sld_subplot_correction(): """Test that if an incorrect number of subplots is corrected in the figure helper.""" fig = plt.subplots(1, 3)[0] - ref_sld_fig = RATplot.plot_ref_sld_helper(data=data(), fig=fig) - assert ref_sld_fig.axes[0].get_subplotspec().get_gridspec().get_geometry() == (1, 2) - assert len(ref_sld_fig.axes) == 2 + RATplot.plot_ref_sld_helper(data=data(), fig=fig) + assert fig.axes[0].get_subplotspec().get_gridspec().get_geometry() == (1, 2) + assert len(fig.axes) == 2 @patch("ratapi.utils.plotting.plot_ref_sld_helper") From 8ea392b7437d1bdcc9643f157cb534767af292c5 Mon Sep 17 00:00:00 2001 From: Paul Sharp <44529197+DrPaulSharp@users.noreply.github.com> Date: Thu, 2 Oct 2025 12:26:28 +0100 Subject: [PATCH 2/4] Adds to "pyproject.toml" to enable setup of uv environment (#182) * Adds code to enable setup of uv environment * Fixes linting errors * Sorts dependencies into groups * Moves dev dependencies into optional requirements * Adds conflicts for uv * Updates submodule with empty layers bug fix * Fixes code hanging bug * Reverts typing change --- .github/workflows/run_tests.yml | 2 +- .gitignore | 3 ++ README.md | 4 +- cpp/RAT | 2 +- pyproject.toml | 50 +++++++++++++++++- ratapi/classlist.py | 19 +++---- ratapi/controls.py | 5 +- ratapi/events.py | 8 +-- ratapi/inputs.py | 4 +- ratapi/outputs.py | 14 +++--- ratapi/project.py | 26 +++++----- ratapi/utils/convert.py | 11 ++-- ratapi/utils/custom_errors.py | 4 +- ratapi/utils/enums.py | 4 +- ratapi/utils/orso.py | 9 ++-- ratapi/utils/plotting.py | 89 ++++++++++++++++++--------------- ratapi/wrappers.py | 2 +- requirements.txt | 14 ------ setup.py | 20 -------- tests/test_classlist.py | 20 ++++---- tests/test_controls.py | 8 +-- tests/test_enums.py | 2 +- tests/test_inputs.py | 5 +- tests/test_models.py | 2 +- tests/test_orso_utils.py | 4 +- tests/test_plotting.py | 10 ++-- tests/test_project.py | 6 +-- tests/utils.py | 8 +-- 28 files changed, 189 insertions(+), 166 deletions(-) delete mode 100644 requirements.txt diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 6936d3ca..c2018067 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -63,5 +63,5 @@ jobs: - name: Install and Test with pytest run: | export PATH="$pythonLocation:$PATH" - python -m pip install -e .[Dev,Orso] + python -m pip install -e .[dev,orso] pytest tests/ --cov=ratapi --cov-report=term diff --git a/.gitignore b/.gitignore index fde40baa..51a84ffe 100644 --- a/.gitignore +++ b/.gitignore @@ -40,3 +40,6 @@ dist/* # Jupyter notebook checkpoints .ipynb_checkpoints/* + +# Lock file for uv env +uv.lock diff --git a/README.md b/README.md index e4566708..46cc348e 100644 --- a/README.md +++ b/README.md @@ -12,11 +12,11 @@ To install in local directory: matlabengine is an optional dependency only required for Matlab custom functions. The version of matlabengine should match the version of Matlab installed on the machine. This can be installed as shown below: - pip install -e .[Matlab-2023a] + pip install -e .[matlab-2023a] Development dependencies can be installed as shown below - pip install -e .[Dev] + pip install -e .[dev] To build wheel: diff --git a/cpp/RAT b/cpp/RAT index aae3dc14..79937719 160000 --- a/cpp/RAT +++ b/cpp/RAT @@ -1 +1 @@ -Subproject commit aae3dc141b6a10c6e10dfb47cd62e07a2a11857d +Subproject commit 7993771968fa7335528c4f14ef44393f0b607953 diff --git a/pyproject.toml b/pyproject.toml index d52f72a5..54a6c8d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,53 @@ requires = [ ] build-backend = 'setuptools.build_meta' +[project] +name = "ratapi" +version = "0.0.0.dev8" +description = "Python extension for the Reflectivity Analysis Toolbox (RAT)" +readme = "README.md" +requires-python = ">=3.10" +dependencies = [ + "matplotlib>=3.8.3", + "numpy>=1.20", + "prettytable>=3.9.0", + "pydantic>=2.7.2", + "scipy>=1.13.1", + "strenum>=0.4.15 ; python_full_version < '3.11'", + "tqdm>=4.66.5", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.4.0", + "pytest-cov>=4.1.0", + "ruff>=0.4.10" +] +orso = [ + "orsopy>=1.2.1", + "pint>=0.24.4" +] +matlab_latest = ["matlabengine"] +matlab_2025b = ["matlabengine == 25.2.*"] +matlab_2025a = ["matlabengine == 25.1.2"] +matlab_2024b = ["matlabengine == 24.2.2"] +matlab_2024a = ["matlabengine == 24.1.4"] +matlab_2023b = ["matlabengine == 23.2.3"] +matlab_2023a = ["matlabengine == 9.14.3"] + +[tool.uv] +conflicts = [ + [ + { extra = "matlab_latest" }, + { extra = "matlab_2025b" }, + { extra = "matlab_2025a" }, + { extra = "matlab_2024b" }, + { extra = "matlab_2024a" }, + { extra = "matlab_2023b" }, + { extra = "matlab_2023a" }, + ], +] + [tool.ruff] line-length = 120 extend-exclude = ["*.ipynb"] @@ -24,7 +71,8 @@ ignore = ["SIM103", # needless bool "D105", # undocumented __init__ "D107", # undocumented magic method "D203", # blank line before class docstring - "D213"] # multi line summary should start at second line + "D213", # multi line summary should start at second line + "UP038"] # non pep604 isinstance - to be removed # ignore docstring lints in the tests and install script [tool.ruff.lint.per-file-ignores] diff --git a/ratapi/classlist.py b/ratapi/classlist.py index f0a61d3b..29637a5f 100644 --- a/ratapi/classlist.py +++ b/ratapi/classlist.py @@ -5,7 +5,7 @@ import importlib import warnings from collections.abc import Sequence -from typing import Any, Generic, TypeVar, Union +from typing import Any, Generic, TypeVar import numpy as np import prettytable @@ -38,7 +38,7 @@ class ClassList(collections.UserList, Generic[T]): """ - def __init__(self, init_list: Union[Sequence[T], T] = None, name_field: str = "name") -> None: + def __init__(self, init_list: Sequence[T] | T = None, name_field: str = "name") -> None: self.name_field = name_field # Set input as list if necessary @@ -114,7 +114,7 @@ def __str__(self): output = str(self.data) return output - def __getitem__(self, index: Union[int, slice, str, T]) -> T: + def __getitem__(self, index: int | slice | str | T) -> T: """Get an item by its index, name, a slice, or the object itself.""" if isinstance(index, (int, slice)): return self.data[index] @@ -262,12 +262,12 @@ def insert(self, index: int, obj: T = None, **kwargs) -> None: self._validate_name_field(kwargs) self.data.insert(index, self._class_handle(**kwargs)) - def remove(self, item: Union[T, str]) -> None: + def remove(self, item: T | str) -> None: """Remove an object from the ClassList using either the object itself or its ``name_field`` value.""" item = self._get_item_from_name_field(item) self.data.remove(item) - def count(self, item: Union[T, str]) -> int: + def count(self, item: T | str) -> int: """Return the number of times an object appears in the ClassList. This method can use either the object itself or its ``name_field`` value. @@ -276,7 +276,7 @@ def count(self, item: Union[T, str]) -> int: item = self._get_item_from_name_field(item) return self.data.count(item) - def index(self, item: Union[T, str], offset: bool = False, *args) -> int: + def index(self, item: T | str, offset: bool = False, *args) -> int: """Return the index of a particular object in the ClassList. This method can use either the object itself or its ``name_field`` value. @@ -309,7 +309,7 @@ def union(self, other: Sequence[T]) -> None: ] ) - def set_fields(self, index: Union[int, slice, str, T], **kwargs) -> None: + def set_fields(self, index: int | slice | str | T, **kwargs) -> None: """Assign the values of an existing object's attributes using keyword arguments.""" self._validate_name_field(kwargs) pydantic_object = False @@ -519,7 +519,7 @@ def _check_classes(self, input_list: Sequence[T]) -> None: f"In the input list:\n{newline.join(error for error in error_list)}\n" ) - def _get_item_from_name_field(self, value: Union[T, str]) -> Union[T, str]: + def _get_item_from_name_field(self, value: T | str) -> T | str: """Return the object with the given value of the ``name_field`` attribute in the ClassList. Parameters @@ -577,11 +577,12 @@ def _determine_class_handle(input_list: Sequence[T]): @classmethod def __get_pydantic_core_schema__(cls, source: Any, handler): # import here so that the ClassList can be instantiated and used without Pydantic installed + from typing import get_args, get_origin + from pydantic import ValidatorFunctionWrapHandler from pydantic.types import ( core_schema, # import core_schema through here rather than making pydantic_core a dependency ) - from typing_extensions import get_args, get_origin # if annotated with a class, get the item type of that class origin = get_origin(source) diff --git a/ratapi/controls.py b/ratapi/controls.py index ea62f5d2..06c457b6 100644 --- a/ratapi/controls.py +++ b/ratapi/controls.py @@ -5,7 +5,6 @@ import tempfile import warnings from pathlib import Path -from typing import Union import prettytable from pydantic import ( @@ -233,7 +232,7 @@ def delete_IPC(self): os.remove(self._IPCFilePath) return None - def save(self, filepath: Union[str, Path] = "./controls.json"): + def save(self, filepath: str | Path = "./controls.json"): """Save a controls object to a JSON file. Parameters @@ -245,7 +244,7 @@ def save(self, filepath: Union[str, Path] = "./controls.json"): filepath.write_text(self.model_dump_json()) @classmethod - def load(cls, path: Union[str, Path]) -> "Controls": + def load(cls, path: str | Path) -> "Controls": """Load a controls object from file. Parameters diff --git a/ratapi/events.py b/ratapi/events.py index 71993ddb..2383159c 100644 --- a/ratapi/events.py +++ b/ratapi/events.py @@ -1,12 +1,12 @@ """Hooks for connecting to run callback events.""" import os -from typing import Callable, Union +from collections.abc import Callable from ratapi.rat_core import EventBridge, EventTypes, PlotEventData, ProgressEventData -def notify(event_type: EventTypes, data: Union[str, PlotEventData, ProgressEventData]) -> None: +def notify(event_type: EventTypes, data: str | PlotEventData | ProgressEventData) -> None: """Call registered callbacks with data when event type has been triggered. Parameters @@ -22,7 +22,7 @@ def notify(event_type: EventTypes, data: Union[str, PlotEventData, ProgressEvent callback(data) -def get_event_callback(event_type: EventTypes) -> list[Callable[[Union[str, PlotEventData, ProgressEventData]], None]]: +def get_event_callback(event_type: EventTypes) -> list[Callable[[str | PlotEventData | ProgressEventData], None]]: """Return all callbacks registered for the given event type. Parameters @@ -39,7 +39,7 @@ def get_event_callback(event_type: EventTypes) -> list[Callable[[Union[str, Plot return list(__event_callbacks[event_type]) -def register(event_type: EventTypes, callback: Callable[[Union[str, PlotEventData, ProgressEventData]], None]) -> None: +def register(event_type: EventTypes, callback: Callable[[str | PlotEventData | ProgressEventData], None]) -> None: """Register a new callback for the event type. Parameters diff --git a/ratapi/inputs.py b/ratapi/inputs.py index 7d0a8720..a537b214 100644 --- a/ratapi/inputs.py +++ b/ratapi/inputs.py @@ -3,7 +3,7 @@ import importlib import os import pathlib -from typing import Callable, Union +from collections.abc import Callable import numpy as np @@ -23,7 +23,7 @@ } -def get_python_handle(file_name: str, function_name: str, path: Union[str, pathlib.Path] = "") -> Callable: +def get_python_handle(file_name: str, function_name: str, path: str | pathlib.Path = "") -> Callable: """Get the function handle from a function defined in a python module located anywhere within the filesystem. Parameters diff --git a/ratapi/outputs.py b/ratapi/outputs.py index b60547d8..2add1653 100644 --- a/ratapi/outputs.py +++ b/ratapi/outputs.py @@ -3,7 +3,7 @@ import json from dataclasses import dataclass from pathlib import Path -from typing import Any, Optional, Union +from typing import Any, Union import numpy as np @@ -244,7 +244,7 @@ def __str__(self): output += get_field_string(key, value, 100) return output - def save(self, filepath: Union[str, Path] = "./results.json"): + def save(self, filepath: str | Path = "./results.json"): """Save the Results object to a JSON file. Parameters @@ -258,7 +258,7 @@ def save(self, filepath: Union[str, Path] = "./results.json"): filepath.write_text(json.dumps(json_dict)) @classmethod - def load(cls, path: Union[str, Path]) -> Union["Results", "BayesResults"]: + def load(cls, path: str | Path) -> Union["Results", "BayesResults"]: """Load a Results object from file. Parameters @@ -538,7 +538,7 @@ class BayesResults(Results): nestedSamplerOutput: NestedSamplerOutput chain: np.ndarray - def save(self, filepath: Union[str, Path] = "./results.json"): + def save(self, filepath: str | Path = "./results.json"): """Save the BayesResults object to a JSON file. Parameters @@ -574,7 +574,7 @@ def save(self, filepath: Union[str, Path] = "./results.json"): filepath.write_text(json.dumps(json_dict)) -def write_core_results_fields(results: Union[Results, BayesResults], json_dict: Optional[dict] = None) -> dict: +def write_core_results_fields(results: Results | BayesResults, json_dict: dict | None = None) -> dict: """Modify the values of the fields that appear in both Results and BayesResults when saving to a json file. Parameters @@ -684,8 +684,8 @@ def read_bayes_results_fields(results_dict: dict) -> dict: def make_results( procedure: Procedures, output_results: ratapi.rat_core.OutputResult, - bayes_results: Optional[ratapi.rat_core.OutputBayesResult] = None, -) -> Union[Results, BayesResults]: + bayes_results: ratapi.rat_core.OutputBayesResult | None = None, +) -> Results | BayesResults: """Initialise a python Results or BayesResults object using the outputs from a RAT calculation. Parameters diff --git a/ratapi/project.py b/ratapi/project.py index 88555539..fcbbaf83 100644 --- a/ratapi/project.py +++ b/ratapi/project.py @@ -5,10 +5,11 @@ import functools import json import warnings +from collections.abc import Callable from enum import Enum from pathlib import Path from textwrap import indent -from typing import Annotated, Any, Callable, Union +from typing import Annotated, Any, get_args, get_origin import numpy as np from pydantic import ( @@ -21,7 +22,6 @@ field_validator, model_validator, ) -from typing_extensions import get_args, get_origin import ratapi.models from ratapi.classlist import ClassList @@ -248,10 +248,10 @@ class Project(BaseModel, validate_assignment=True, extra="forbid", use_attribute data: ClassList[ratapi.models.Data] = ClassList() """Experimental data for a model.""" - layers: Union[ - Annotated[ClassList[ratapi.models.Layer], Tag("no_abs")], - Annotated[ClassList[ratapi.models.AbsorptionLayer], Tag("abs")], - ] = Field( + layers: ( + Annotated[ClassList[ratapi.models.Layer], Tag("no_abs")] + | Annotated[ClassList[ratapi.models.AbsorptionLayer], Tag("abs")] + ) = Field( default=ClassList(), discriminator=Discriminator( discriminate_layers, @@ -265,10 +265,10 @@ class Project(BaseModel, validate_assignment=True, extra="forbid", use_attribute domain_contrasts: ClassList[ratapi.models.DomainContrast] = ClassList() """The groups of layers required by each domain in a domains model.""" - contrasts: Union[ - Annotated[ClassList[ratapi.models.Contrast], Tag("no_ratio")], - Annotated[ClassList[ratapi.models.ContrastWithRatio], Tag("ratio")], - ] = Field( + contrasts: ( + Annotated[ClassList[ratapi.models.Contrast], Tag("no_ratio")] + | Annotated[ClassList[ratapi.models.ContrastWithRatio], Tag("ratio")] + ) = Field( default=ClassList(), discriminator=Discriminator( discriminate_contrasts, @@ -577,7 +577,7 @@ def update_renamed_models(self) -> "Project": old_names = self._all_names[class_list] new_names = getattr(self, class_list).get_names() if len(old_names) == len(new_names): - name_diff = [(old, new) for (old, new) in zip(old_names, new_names) if old != new] + name_diff = [(old, new) for (old, new) in zip(old_names, new_names, strict=False) if old != new] for old_name, new_name in name_diff: for field in fields_to_update: project_field = getattr(self, field.attribute) @@ -927,7 +927,7 @@ def classlist_script(name, classlist): + "\n)" ) - def save(self, filepath: Union[str, Path] = "./project.json"): + def save(self, filepath: str | Path = "./project.json"): """Save a project to a JSON file. Parameters @@ -973,7 +973,7 @@ def make_custom_file_dict(item): filepath.write_text(json.dumps(json_dict)) @classmethod - def load(cls, path: Union[str, Path]) -> "Project": + def load(cls, path: str | Path) -> "Project": """Load a project from file. Parameters diff --git a/ratapi/utils/convert.py b/ratapi/utils/convert.py index b689317f..4e2649aa 100644 --- a/ratapi/utils/convert.py +++ b/ratapi/utils/convert.py @@ -4,7 +4,6 @@ from collections.abc import Iterable from os import PathLike from pathlib import Path -from typing import Union from numpy import array, empty from scipy.io.matlab import MatlabOpaque, loadmat @@ -15,7 +14,7 @@ from ratapi.utils.enums import Geometries, Languages, LayerModels -def r1_to_project(filename: Union[str, PathLike]) -> Project: +def r1_to_project(filename: str | PathLike) -> Project: """Read a RasCAL1 project struct as a Python `Project`. Parameters @@ -43,7 +42,7 @@ def r1_to_project(filename: Union[str, PathLike]) -> Project: layer_model = LayerModels.CustomXY layer_model = LayerModels(layer_model) - def zip_if_several(*params) -> Union[tuple, list[tuple]]: + def zip_if_several(*params) -> tuple | list[tuple]: """Zips parameters if necessary, but can handle single-item parameters. Examples @@ -64,7 +63,7 @@ def zip_if_several(*params) -> Union[tuple, list[tuple]]: """ if all(isinstance(param, Iterable) and not isinstance(param, str) for param in params): - return zip(*params) + return zip(*params, strict=False) return [params] def read_param(names, constrs, values, fits): @@ -319,8 +318,8 @@ def fix_invalid_constraints(name: str, constrs: tuple[float, float], value: floa def project_to_r1( - project: Project, filename: Union[str, PathLike] = "RAT_project", return_struct: bool = False -) -> Union[dict, None]: + project: Project, filename: str | PathLike = "RAT_project", return_struct: bool = False +) -> dict | None: """Convert a RAT Project to a RasCAL1 project struct. Parameters diff --git a/ratapi/utils/custom_errors.py b/ratapi/utils/custom_errors.py index 425cf9ef..83bf084f 100644 --- a/ratapi/utils/custom_errors.py +++ b/ratapi/utils/custom_errors.py @@ -1,13 +1,11 @@ """Defines routines for custom error handling in RAT.""" -from typing import Optional - import pydantic_core def custom_pydantic_validation_error( error_list: list[pydantic_core.ErrorDetails], - custom_error_msgs: Optional[dict[str, str]] = None, + custom_error_msgs: dict[str, str] | None = None, ) -> list[pydantic_core.ErrorDetails]: """Give Pydantic errors a better custom message with extraneous information removed. diff --git a/ratapi/utils/enums.py b/ratapi/utils/enums.py index 24c50cbc..313f04c7 100644 --- a/ratapi/utils/enums.py +++ b/ratapi/utils/enums.py @@ -1,7 +1,5 @@ """The Enum values used in the parameters of various ratapi classes and functions.""" -from typing import Union - try: from enum import StrEnum except ImportError: @@ -92,7 +90,7 @@ class Strategies(RATEnum): or a pure recombination of parent parameter values.""" @classmethod - def _missing_(cls, value: Union[int, str]): + def _missing_(cls, value: int | str): # legacy compatibility with strategies being 1-indexed ints under the hood if isinstance(value, int): if value < 1 or value > 6: diff --git a/ratapi/utils/orso.py b/ratapi/utils/orso.py index 403597b4..2d0345fc 100644 --- a/ratapi/utils/orso.py +++ b/ratapi/utils/orso.py @@ -4,7 +4,6 @@ from itertools import count from pathlib import Path from textwrap import shorten -from typing import Union import orsopy import prettytable @@ -26,7 +25,7 @@ class ORSOProject: """ - def __init__(self, filepath: Union[str, Path], absorption: bool = False): + def __init__(self, filepath: str | Path, absorption: bool = False): ort_data = load_orso(filepath) datasets = [Data(name=dataset.info.data_source.sample.name, data=dataset.data) for dataset in ort_data] # orso datasets in the same file can have repeated names! @@ -75,7 +74,7 @@ class ORSOSample: bulk_in: Parameter bulk_out: Parameter parameters: ClassList[Parameter] - layers: Union[ClassList[Layer], ClassList[AbsorptionLayer]] + layers: ClassList[Layer] | ClassList[AbsorptionLayer] model: list[str] def __str__(self): @@ -94,8 +93,8 @@ def __str__(self): def orso_model_to_rat( - model: Union[orsopy.fileio.model_language.SampleModel, str], absorption: bool = False -) -> Union[ORSOSample, None]: + model: orsopy.fileio.model_language.SampleModel | str, absorption: bool = False +) -> ORSOSample | None: """Get information from an ORSO SampleModel object. Parameters diff --git a/ratapi/utils/plotting.py b/ratapi/utils/plotting.py index c2823b8f..50e0d448 100644 --- a/ratapi/utils/plotting.py +++ b/ratapi/utils/plotting.py @@ -2,12 +2,14 @@ import copy import types +from collections.abc import Callable from functools import partial, wraps from math import ceil, floor, sqrt from statistics import stdev -from typing import Callable, Literal, Optional, Union +from typing import Literal import matplotlib +import matplotlib.figure import matplotlib.pyplot as plt import matplotlib.transforms as mtransforms import numpy as np @@ -47,7 +49,9 @@ def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool if shift_value < 1 or shift_value > 100: raise ValueError("Parameter `shift_value` must be between 1 and 100") - for i, (r, data, sld) in enumerate(zip(event_data.reflectivity, event_data.shiftedData, event_data.sldProfiles)): + for i, (r, data, sld) in enumerate( + zip(event_data.reflectivity, event_data.shiftedData, event_data.sldProfiles, strict=False) + ): # Calculate the divisor div = 1 if i == 0 and not q4 else 10 ** ((i / 100) * shift_value) q4_data = 1 if not q4 or not event_data.dataPresent[i] else data[:, 0] ** 4 @@ -94,9 +98,9 @@ def _extract_plot_data(event_data: PlotEventData, q4: bool, show_error_bar: bool def plot_ref_sld_helper( data: PlotEventData, - fig: matplotlib.pyplot.figure, + fig: matplotlib.figure.Figure, delay: bool = True, - confidence_intervals: Union[dict, None] = None, + confidence_intervals: dict | None = None, linear_x: bool = False, q4: bool = False, show_error_bar: bool = True, @@ -112,7 +116,7 @@ def plot_ref_sld_helper( data : PlotEventData The plot event data that contains all the information to generate the ref and sld plots - fig : matplotlib.pyplot.figure + fig : matplotlib.figure.Figure The figure object that has two subplots delay : bool, default: True Controls whether to delay 0.005s after plot is created @@ -230,9 +234,9 @@ def plot_ref_sld_helper( def plot_ref_sld( project: ratapi.Project, - results: Union[ratapi.outputs.Results, ratapi.outputs.BayesResults], + results: ratapi.outputs.Results | ratapi.outputs.BayesResults, block: bool = False, - fig: Optional[matplotlib.pyplot.figure] = None, + fig: matplotlib.figure.Figure | None = None, return_fig: bool = False, bayes: Literal[65, 95, None] = None, linear_x: bool = False, @@ -241,7 +245,7 @@ def plot_ref_sld( show_grid: bool = False, show_legend: bool = True, shift_value: float = 100, -) -> Union[plt.Figure, None]: +) -> plt.Figure | None: """Plot the reflectivity and SLD profiles. Parameters @@ -252,7 +256,7 @@ def plot_ref_sld( The result from the calculation block : bool, default: False Indicates the plot should block until it is closed - fig : matplotlib.pyplot.figure, optional + fig : matplotlib.figure.Figure, optional The figure object that has two subplots return_fig : bool, default False If True, return the figure instead of displaying it. @@ -319,10 +323,12 @@ def plot_ref_sld( ], } # For a shaded plot, use the mean values from predictionIntervals - for reflectivity, mean_reflectivity in zip(data.reflectivity, results.predictionIntervals.reflectivity): + for reflectivity, mean_reflectivity in zip( + data.reflectivity, results.predictionIntervals.reflectivity, strict=False + ): reflectivity[:, 1] = mean_reflectivity[2] - for sldProfile, mean_sld_profile in zip(data.sldProfiles, results.predictionIntervals.sld): - for sld, mean_sld in zip(sldProfile, mean_sld_profile): + for sldProfile, mean_sld_profile in zip(data.sldProfiles, results.predictionIntervals.sld, strict=False): + for sld, mean_sld in zip(sldProfile, mean_sld_profile, strict=False): sld[:, 1] = mean_sld[2] else: raise ValueError( @@ -366,7 +372,7 @@ class BlittingSupport: data : PlotEventData The plot event data that contains all the information to generate the ref and sld plots - fig : matplotlib.pyplot.figure, optional + fig : matplotlib.figure.Figure, optional The figure class that has two subplots linear_x : bool, default: False Controls whether the x-axis on reflectivity plot uses the linear scale @@ -471,7 +477,9 @@ def adjust_error_bar(self, error_bar_container, x, y, y_error): y_error_top = y_base + y_error y_error_bottom = y_base - y_error - new_segments_y = [np.array([[x, yt], [x, yb]]) for x, yt, yb in zip(x_base, y_error_top, y_error_bottom)] + new_segments_y = [ + np.array([[x, yt], [x, yb]]) for x, yt, yb in zip(x_base, y_error_top, y_error_bottom, strict=False) + ] bars_y.set_segments(new_segments_y) def update_plot(self, data): @@ -628,7 +636,7 @@ def inner(results, *args, **kwargs): return decorator -def name_to_index(param: Union[str, int], names: list[str]): +def name_to_index(param: str | int, names: list[str]): """Convert parameter names to indices.""" if isinstance(param, str): if param not in names: @@ -645,14 +653,14 @@ def name_to_index(param: Union[str, int], names: list[str]): @assert_bayesian("Corner") def plot_corner( results: ratapi.outputs.BayesResults, - params: Union[list[Union[int, str]], None] = None, + params: list[int | str] | None = None, smooth: bool = True, block: bool = False, - fig: Optional[matplotlib.pyplot.figure] = None, + fig: matplotlib.figure.Figure | None = None, return_fig: bool = False, - hist_kwargs: Union[dict, None] = None, - hist2d_kwargs: Union[dict, None] = None, - progress_callback: Union[Callable[[int, int], None], None] = None, + hist_kwargs: dict | None = None, + hist2d_kwargs: dict | None = None, + progress_callback: Callable[[int, int], None] | None = None, ): """Create a corner plot from a Bayesian analysis. @@ -667,7 +675,7 @@ def plot_corner( Whether to apply Gaussian smoothing to the corner plot. block : bool, default False Whether Python should block until the plot is closed. - fig : matplotlib.pyplot.figure, optional + fig : matplotlib.figure.Figure, optional The figure object to use for plot. return_fig: bool, default False If True, return the figure as an object instead of showing it. @@ -750,11 +758,11 @@ def plot_corner( @assert_bayesian("Histogram") def plot_one_hist( results: ratapi.outputs.BayesResults, - param: Union[int, str], + param: int | str, smooth: bool = True, - sigma: Union[float, None] = None, + sigma: float | None = None, estimated_density: Literal["normal", "lognor", "kernel", None] = None, - axes: Union[Axes, None] = None, + axes: Axes | None = None, block: bool = False, return_fig: bool = False, **hist_settings, @@ -901,11 +909,11 @@ def _y_update_offset_text_position(axis, _bboxes, bboxes2): @assert_bayesian("Contour") def plot_contour( results: ratapi.outputs.BayesResults, - x_param: Union[int, str], - y_param: Union[int, str], + x_param: int | str, + y_param: int | str, smooth: bool = True, - sigma: Union[tuple[float], None] = None, - axes: Union[Axes, None] = None, + sigma: tuple[float] | None = None, + axes: Axes | None = None, block: bool = False, return_fig: bool = False, **hist2d_settings, @@ -974,7 +982,7 @@ def plot_contour( def panel_plot_helper( - plot_func: Callable, indices: list[int], fig: Optional[matplotlib.pyplot.figure] = None + plot_func: Callable, indices: list[int], fig: matplotlib.figure.Figure | None = None ) -> matplotlib.figure.Figure: """Generate a panel-based plot from a single plot function. @@ -984,7 +992,7 @@ def panel_plot_helper( A function which plots one parameter on an Axes object, given its index. indices : list[int] The list of indices to pass into ``plot_func``. - fig : matplotlib.pyplot.figure, optional + fig : matplotlib.figure.Figure, optional The figure object to use for plot. Returns @@ -1020,14 +1028,13 @@ def panel_plot_helper( @assert_bayesian("Histogram") def plot_hists( results: ratapi.outputs.BayesResults, - params: Union[list[Union[int, str]], None] = None, + params: list[int | str] | None = None, smooth: bool = True, - sigma: Union[float, None] = None, - estimated_density: Union[ - dict[Literal["normal", "lognor", "kernel", None]], Literal["normal", "lognor", "kernel", None] - ] = None, + sigma: float | None = None, + estimated_density: dict[Literal["normal", "lognor", "kernel", None]] + | Literal["normal", "lognor", "kernel", None] = None, block: bool = False, - fig: Optional[matplotlib.pyplot.figure] = None, + fig: matplotlib.figure.Figure | None = None, return_fig: bool = False, **hist_settings, ): @@ -1061,7 +1068,7 @@ def plot_hists( e.g. to apply 'normal' to all unset parameters, set `estimated_density = {'default': 'normal'}`. block : bool, default False Whether Python should block until the plot is closed. - fig : matplotlib.pyplot.figure, optional + fig : matplotlib.figure.Figure, optional The figure object to use for plot. return_fig: bool, default False If True, return the figure as an object instead of showing it. @@ -1085,7 +1092,7 @@ def plot_hists( if estimated_density is not None: - def validate_dens_type(dens_type: Union[str, None], param: str): + def validate_dens_type(dens_type: str | None, param: str): """Check estimated density is a supported type.""" if dens_type not in [None, "normal", "lognor", "kernel"]: raise ValueError( @@ -1132,10 +1139,10 @@ def validate_dens_type(dens_type: Union[str, None], param: str): @assert_bayesian("Chain") def plot_chain( results: ratapi.outputs.BayesResults, - params: Union[list[Union[int, str]], None] = None, + params: list[int | str] | None = None, maxpoints: int = 15000, block: bool = False, - fig: Optional[matplotlib.pyplot.figure] = None, + fig: matplotlib.figure.Figure | None = None, return_fig: bool = False, ): """Plot the MCMC chain for each parameter of a Bayesian analysis. @@ -1151,7 +1158,7 @@ def plot_chain( The maximum number of points to plot for each parameter. block : bool, default False Whether Python should block until the plot is closed. - fig : matplotlib.pyplot.figure, optional + fig : matplotlib.figure.Figure, optional The figure object to use for plot. return_fig: bool, default False If True, return the figure as an object instead of showing it. diff --git a/ratapi/wrappers.py b/ratapi/wrappers.py index 39021e29..74eda41e 100644 --- a/ratapi/wrappers.py +++ b/ratapi/wrappers.py @@ -2,8 +2,8 @@ import os import pathlib +from collections.abc import Callable from contextlib import suppress -from typing import Callable import numpy as np from numpy.typing import ArrayLike diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index cee9a790..00000000 --- a/requirements.txt +++ /dev/null @@ -1,14 +0,0 @@ -numpy >= 1.20 -scipy >= 1.13.1 -prettytable >= 3.9.0 -pybind11 >= 2.4 -pydantic >= 2.7.2 -pytest >= 7.4.0 -pytest-cov >= 4.1.0 -matplotlib >= 3.8.3 -StrEnum >= 0.4.15; python_version < '3.11' -ruff >= 0.4.10 -scipy >= 1.13.1 -tqdm >= 4.66.5 -orsopy >= 1.2.1 -pint >= 0.24.4 diff --git a/setup.py b/setup.py index 4c996362..3c21a01f 100644 --- a/setup.py +++ b/setup.py @@ -165,25 +165,5 @@ def build_libraries(self, libraries): cmdclass={"build_clib": BuildClib, "build_ext": BuildExt}, libraries=[libevent], ext_modules=ext_modules, - python_requires=">=3.10", - install_requires=[ - "numpy >= 1.20", - "prettytable >= 3.9.0", - "pydantic >= 2.7.2", - "matplotlib >= 3.8.3", - "scipy >= 1.13.1", - "tqdm >= 4.66.5", - ], - extras_require={ - ':python_version < "3.11"': ["StrEnum >= 0.4.15"], - "Dev": ["pytest>=7.4.0", "pytest-cov>=4.1.0", "ruff>=0.4.10"], - "Orso": ["orsopy>=1.2.1", "pint>=0.24.4"], - "Matlab_latest": ["matlabengine"], - "Matlab_2025a": ["matlabengine == 25.1.*"], - "Matlab_2024b": ["matlabengine == 24.2.2"], - "Matlab_2024a": ["matlabengine == 24.1.4"], - "Matlab_2023b": ["matlabengine == 23.2.3"], - "Matlab_2023a": ["matlabengine == 9.14.3"], - }, zip_safe=False, ) diff --git a/tests/test_classlist.py b/tests/test_classlist.py index 04130ff8..98d9440c 100644 --- a/tests/test_classlist.py +++ b/tests/test_classlist.py @@ -5,7 +5,7 @@ import warnings from collections import deque from collections.abc import Iterable, Sequence -from typing import Any, Union +from typing import Any import prettytable import pytest @@ -611,7 +611,7 @@ def test_insert_kwargs_same_name(two_name_class_list: ClassList, new_values: dic (InputAttributes(name="Bob")), ], ) -def test_remove(two_name_class_list: ClassList, remove_value: Union[object, str]) -> None: +def test_remove(two_name_class_list: ClassList, remove_value: object | str) -> None: """We should be able to remove an object either by the value of the name_field or by specifying the object itself. """ @@ -626,7 +626,7 @@ def test_remove(two_name_class_list: ClassList, remove_value: Union[object, str] (InputAttributes(name="Eve")), ], ) -def test_remove_not_present(two_name_class_list: ClassList, remove_value: Union[object, str]) -> None: +def test_remove_not_present(two_name_class_list: ClassList, remove_value: object | str) -> None: """If we remove an object not included in the ClassList we should raise a ValueError.""" with pytest.raises(ValueError, match=re.escape("list.remove(x): x not in list")): two_name_class_list.remove(remove_value) @@ -641,7 +641,7 @@ def test_remove_not_present(two_name_class_list: ClassList, remove_value: Union[ (InputAttributes(name="Eve"), 0), ], ) -def test_count(two_name_class_list: ClassList, count_value: Union[object, str], expected_count: int) -> None: +def test_count(two_name_class_list: ClassList, count_value: object | str, expected_count: int) -> None: """We should be able to determine the number of times an object is in the ClassList using either the object itself or its name_field value. """ @@ -655,7 +655,7 @@ def test_count(two_name_class_list: ClassList, count_value: Union[object, str], (InputAttributes(name="Bob"), 1), ], ) -def test_index(two_name_class_list: ClassList, index_value: Union[object, str], expected_index: int) -> None: +def test_index(two_name_class_list: ClassList, index_value: object | str, expected_index: int) -> None: """We should be able to find the index of an object in the ClassList either by its name_field value or by specifying the object itself. """ @@ -671,7 +671,7 @@ def test_index(two_name_class_list: ClassList, index_value: Union[object, str], ) def test_index_offset( two_name_class_list: ClassList, - index_value: Union[object, str], + index_value: object | str, offset: int, expected_index: int, ) -> None: @@ -688,7 +688,7 @@ def test_index_offset( (InputAttributes(name="Eve")), ], ) -def test_index_not_present(two_name_class_list: ClassList, index_value: Union[object, str]) -> None: +def test_index_not_present(two_name_class_list: ClassList, index_value: object | str) -> None: """If we try to find the index of an object not included in the ClassList we should raise a ValueError.""" # with pytest.raises(ValueError, match=f"'{index_value}' is not in list") as e: with pytest.raises(ValueError): @@ -741,7 +741,7 @@ def test_extend_empty_classlist(extended_list: Sequence, one_name_class_list: Cl ], ) def test_set_fields( - two_name_class_list: ClassList, index: Union[int, str], new_values: dict[str, Any], expected_classlist: ClassList + two_name_class_list: ClassList, index: int | str, new_values: dict[str, Any], expected_classlist: ClassList ) -> None: """We should be able to set field values in an element of a ClassList using keyword arguments.""" class_list = two_name_class_list @@ -963,7 +963,7 @@ def test__check_classes_different_classes(input_list: Sequence) -> None: def test__get_item_from_name_field( two_name_class_list: ClassList, value: str, - expected_output: Union[object, str], + expected_output: object | str, ) -> None: """When we input the name_field value of an object defined in the ClassList, we should return the object. If the value is not the name_field of an object defined in the ClassList, we should return the value. @@ -1044,7 +1044,7 @@ class NestedModel(pydantic.BaseModel): submodels_list = [{"i": 3, "s": "hello", "f": 3.0}, {"i": 4, "s": "hi", "f": 3.14}] model = NestedModel(submodels=submodels_list) - for submodel, exp_dict in zip(model.submodels, submodels_list): + for submodel, exp_dict in zip(model.submodels, submodels_list, strict=False): for key, value in exp_dict.items(): assert getattr(submodel, key) == value diff --git a/tests/test_controls.py b/tests/test_controls.py index 61d68331..ec8cb128 100644 --- a/tests/test_controls.py +++ b/tests/test_controls.py @@ -4,7 +4,7 @@ import os import tempfile from pathlib import Path -from typing import Any, Union +from typing import Any import pydantic import pytest @@ -366,7 +366,7 @@ def test_set_non_simplex_properties(self, wrong_property: str, value: Any) -> No ("maxIterations", -50), ], ) - def test_simplex_property_errors(self, control_property: str, value: Union[float, int]) -> None: + def test_simplex_property_errors(self, control_property: str, value: float | int) -> None: """Tests the property errors of Simplex class.""" with pytest.raises(pydantic.ValidationError, match="Input should be greater than 0"): setattr(self.simplex, control_property, value) @@ -538,7 +538,7 @@ def test_de_crossoverProbability_error(self, value: int, msg: str) -> None: def test_de_targetValue_numGenerations_populationSize_error( self, control_property: str, - value: Union[int, float], + value: int | float, ) -> None: """Tests the targetValue, numGenerations, populationSize setter error in DE class.""" with pytest.raises(pydantic.ValidationError, match="Input should be greater than or equal to 1"): @@ -693,7 +693,7 @@ def test_set_non_ns_properties(self, wrong_property: str, value: Any) -> None: ("nLive", -500, 1), ], ) - def test_ns_setter_error(self, control_property: str, value: Union[int, float], bound: int) -> None: + def test_ns_setter_error(self, control_property: str, value: int | float, bound: int) -> None: """Tests the nMCMC, nsTolerance, nLive setter error in NS class.""" with pytest.raises(pydantic.ValidationError, match=f"Input should be greater than or equal to {bound}"): setattr(self.ns, control_property, value) diff --git a/tests/test_enums.py b/tests/test_enums.py index 9984185c..3d07c308 100644 --- a/tests/test_enums.py +++ b/tests/test_enums.py @@ -1,6 +1,6 @@ """Tests the enums module.""" -from typing import Callable +from collections.abc import Callable import pytest diff --git a/tests/test_inputs.py b/tests/test_inputs.py index e7a1560e..bf9a2e17 100644 --- a/tests/test_inputs.py +++ b/tests/test_inputs.py @@ -736,11 +736,12 @@ def check_problem_equal(actual_problem, expected_problem) -> None: # Data field is a numpy array assert [ - actual_data == expected_data for (actual_data, expected_data) in zip(actual_problem.data, expected_problem.data) + actual_data == expected_data + for (actual_data, expected_data) in zip(actual_problem.data, expected_problem.data, strict=False) ] # Need to account for "NaN" entries in layersDetails and contrastCustomFiles field - for actual_layer, expected_layer in zip(actual_problem.layersDetails, expected_problem.layersDetails): + for actual_layer, expected_layer in zip(actual_problem.layersDetails, expected_problem.layersDetails, strict=False): assert (actual_layer == expected_layer) or ["NaN" if np.isnan(el) else el for el in actual_layer] == [ "NaN" if np.isnan(el) else el for el in expected_layer ] diff --git a/tests/test_models.py b/tests/test_models.py index 0b00dea9..906a5d67 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -2,7 +2,7 @@ import pathlib import re -from typing import Callable +from collections.abc import Callable import numpy as np import pydantic diff --git a/tests/test_orso_utils.py b/tests/test_orso_utils.py index a8c0791a..60f04494 100644 --- a/tests/test_orso_utils.py +++ b/tests/test_orso_utils.py @@ -95,7 +95,7 @@ def test_load_ort_data(test_data): actual_data = ORSOProject(Path(TEST_DIR_PATH, test_data)).data assert len(actual_data) == len(expected_data) - for actual_dataset, expected_dataset in zip(actual_data, expected_data): + for actual_dataset, expected_dataset in zip(actual_data, expected_data, strict=False): np.testing.assert_array_equal(actual_dataset.data, expected_dataset) @@ -118,5 +118,5 @@ def test_load_ort_project(test_data, expected_data): assert sample.parameters == exp_project.parameters[1:] assert sample.layers == exp_project.layers - for data, exp_data in zip(ort_data.data, exp_project.data[1:]): + for data, exp_data in zip(ort_data.data, exp_project.data[1:], strict=False): np.testing.assert_array_equal(data.data, exp_data.data) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 222d5142..710729a5 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -194,10 +194,14 @@ def test_plot_ref_sld(mock: MagicMock, input_project, reflectivity_calculation_r assert figure.axes[0].get_subplotspec().get_gridspec().get_geometry() == (1, 2) assert len(figure.axes) == 2 - for reflectivity, reflectivity_results in zip(data.reflectivity, reflectivity_calculation_results.reflectivity): + for reflectivity, reflectivity_results in zip( + data.reflectivity, reflectivity_calculation_results.reflectivity, strict=False + ): assert (reflectivity == reflectivity_results).all() - for sldProfile, result_sld_profile in zip(data.sldProfiles, reflectivity_calculation_results.sldProfiles): - for sld, sld_results in zip(sldProfile, result_sld_profile): + for sldProfile, result_sld_profile in zip( + data.sldProfiles, reflectivity_calculation_results.sldProfiles, strict=False + ): + for sld, sld_results in zip(sldProfile, result_sld_profile, strict=False): assert (sld == sld_results).all() assert data.modelType == input_project.model diff --git a/tests/test_project.py b/tests/test_project.py index 7683dec9..31e13760 100644 --- a/tests/test_project.py +++ b/tests/test_project.py @@ -4,13 +4,13 @@ import re import tempfile import warnings +from collections.abc import Callable from pathlib import Path -from typing import Callable +from typing import get_args, get_origin import numpy as np import pydantic import pytest -from typing_extensions import get_args, get_origin import ratapi from ratapi.utils.enums import Calculations, LayerModels, TypeOptions @@ -667,7 +667,7 @@ def test_rename_models(test_project, model: str, fields: list[str]) -> None: getattr(test_project, model).set_fields(-1, name="New Name") model_name_lists = ratapi.project.model_names_used_in[model] - for model_name_list, field in zip(model_name_lists, fields): + for model_name_list, field in zip(model_name_lists, fields, strict=False): attribute = model_name_list.attribute assert getattr(getattr(test_project, attribute)[-1], field) == "New Name" diff --git a/tests/utils.py b/tests/utils.py index 91b5b9ae..5387b30a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -48,7 +48,7 @@ def check_results_equal(actual_results, expected_results) -> None: # The first set of fields are either 1D or 2D python lists containing numpy arrays. # Hence, we need to compare them element-wise. for list_field in ratapi.outputs.results_fields["list_fields"]: - for a, b in zip(getattr(actual_results, list_field), getattr(expected_results, list_field)): + for a, b in zip(getattr(actual_results, list_field), getattr(expected_results, list_field), strict=False): assert (a == b).all() for list_field in ratapi.outputs.results_fields["double_list_fields"]: @@ -56,7 +56,7 @@ def check_results_equal(actual_results, expected_results) -> None: expected_list = getattr(expected_results, list_field) assert len(actual_list) == len(expected_list) for i in range(len(actual_list)): - for a, b in zip(actual_list[i], expected_list[i]): + for a, b in zip(actual_list[i], expected_list[i], strict=False): assert (a == b).all() # Compare the final fields @@ -90,7 +90,7 @@ def check_bayes_fields_equal(actual_results, expected_results) -> None: assert getattr(actual_subclass, field) == getattr(expected_subclass, field) for field in ratapi.outputs.bayes_results_fields["list_fields"][subclass]: - for a, b in zip(getattr(actual_subclass, field), getattr(expected_subclass, field)): + for a, b in zip(getattr(actual_subclass, field), getattr(expected_subclass, field), strict=False): assert (a == b).all() for field in ratapi.outputs.bayes_results_fields["double_list_fields"][subclass]: @@ -98,7 +98,7 @@ def check_bayes_fields_equal(actual_results, expected_results) -> None: expected_list = getattr(expected_subclass, field) assert len(actual_list) == len(expected_list) for i in range(len(actual_list)): - for a, b in zip(actual_list[i], expected_list[i]): + for a, b in zip(actual_list[i], expected_list[i], strict=False): assert (a == b).all() # Need to account for the arrays that are initialised as "NaN" in the compiled code From 3345860fc6c1bdd443cdeaa3e04f122657417fe3 Mon Sep 17 00:00:00 2001 From: StephenNneji <34302892+StephenNneji@users.noreply.github.com> Date: Mon, 6 Oct 2025 09:59:47 +0100 Subject: [PATCH 3/4] Fixes GLIBC Linux wheel issue and adds more progress callbacks (#187) * Change MANYLINUX_X86_64_IMAGE to fix issue on centos 8 * Add progress callback for panel plot helper --- .github/workflows/build_wheel.yml | 1 + ratapi/utils/plotting.py | 41 ++++++++++++++++++++----------- 2 files changed, 28 insertions(+), 14 deletions(-) diff --git a/.github/workflows/build_wheel.yml b/.github/workflows/build_wheel.yml index 3fa3f319..1a966d3d 100644 --- a/.github/workflows/build_wheel.yml +++ b/.github/workflows/build_wheel.yml @@ -15,6 +15,7 @@ jobs: env: CIBW_SKIP: 'pp*' CIBW_ARCHS: 'auto64' + CIBW_MANYLINUX_X86_64_IMAGE: 'manylinux_2_28' CIBW_PROJECT_REQUIRES_PYTHON: '>=3.10' CIBW_TEST_REQUIRES: 'pytest' defaults: diff --git a/ratapi/utils/plotting.py b/ratapi/utils/plotting.py index 50e0d448..30ab8947 100644 --- a/ratapi/utils/plotting.py +++ b/ratapi/utils/plotting.py @@ -982,7 +982,10 @@ def plot_contour( def panel_plot_helper( - plot_func: Callable, indices: list[int], fig: matplotlib.figure.Figure | None = None + plot_func: Callable, + indices: list[int], + fig: matplotlib.figure.Figure | None = None, + progress_callback: Callable[[int, int], None] | None = None, ) -> matplotlib.figure.Figure: """Generate a panel-based plot from a single plot function. @@ -994,6 +997,9 @@ def panel_plot_helper( The list of indices to pass into ``plot_func``. fig : matplotlib.figure.Figure, optional The figure object to use for plot. + progress_callback: Union[Callable[[int, int], None], None] + Callback function for providing progress during plot creation + First argument is current completed sub plot and second is total number of sub plots Returns ------- @@ -1005,21 +1011,19 @@ def panel_plot_helper( nrows, ncols = ceil(sqrt(nplots)), round(sqrt(nplots)) if fig is None: - fig = plt.subplots(nrows, ncols, figsize=(11, 10))[0] + fig = plt.subplots(nrows, ncols, figsize=(11, 10), subplot_kw={"visible": False})[0] else: fig.clf() - fig.subplots(nrows, ncols) + fig.subplots(nrows, ncols, subplot_kw={"visible": False}) axs = fig.get_axes() - - for plot_num, index in enumerate(indices): - axs[plot_num].tick_params(which="both", labelsize="medium") - axs[plot_num].xaxis.offsetText.set_fontsize("small") - axs[plot_num].yaxis.offsetText.set_fontsize("small") - plot_func(axs[plot_num], index) - - # blank unused plots - for i in range(nplots, len(axs)): - axs[i].set_visible(False) + for index, plot_num in enumerate(indices): + axs[index].tick_params(which="both", labelsize="medium") + axs[index].xaxis.offsetText.set_fontsize("small") + axs[index].yaxis.offsetText.set_fontsize("small") + axs[index].set_visible(True) + plot_func(axs[index], plot_num) + if progress_callback is not None: + progress_callback(index, nplots) fig.tight_layout() return fig @@ -1036,6 +1040,7 @@ def plot_hists( block: bool = False, fig: matplotlib.figure.Figure | None = None, return_fig: bool = False, + progress_callback: Callable[[int, int], None] | None = None, **hist_settings, ): """Plot marginalised posteriors for several parameters from a Bayesian analysis. @@ -1072,6 +1077,9 @@ def plot_hists( The figure object to use for plot. return_fig: bool, default False If True, return the figure as an object instead of showing it. + progress_callback: Union[Callable[[int, int], None], None] + Callback function for providing progress during plot creation + First argument is current completed sub plot and second is total number of sub plots hist_settings : Settings passed to `np.histogram`. By default, the settings passed are `bins = 25` and `density = True`. @@ -1130,6 +1138,7 @@ def validate_dens_type(dens_type: str | None, param: str): ), params, fig, + progress_callback, ) if return_fig: return fig @@ -1144,6 +1153,7 @@ def plot_chain( block: bool = False, fig: matplotlib.figure.Figure | None = None, return_fig: bool = False, + progress_callback: Callable[[int, int], None] | None = None, ): """Plot the MCMC chain for each parameter of a Bayesian analysis. @@ -1162,6 +1172,9 @@ def plot_chain( The figure object to use for plot. return_fig: bool, default False If True, return the figure as an object instead of showing it. + progress_callback: Union[Callable[[int, int], None], None] + Callback function for providing progress during plot creation + First argument is current completed sub plot and second is total number of sub plots Returns ------- @@ -1187,7 +1200,7 @@ def plot_one_chain(axes: Axes, i: int): axes.plot(range(0, nsimulations, skip), chain[:, i][0:nsimulations:skip]) axes.set_title(results.fitNames[i], fontsize="small") - fig = panel_plot_helper(plot_one_chain, params, fig=fig) + fig = panel_plot_helper(plot_one_chain, params, fig, progress_callback) if return_fig: return fig plt.show(block=block) From 58d3aa6d96572087f6198bdf94e1c9b0375f6f6c Mon Sep 17 00:00:00 2001 From: StephenNneji <34302892+StephenNneji@users.noreply.github.com> Date: Fri, 10 Oct 2025 17:55:56 +0100 Subject: [PATCH 4/4] Updates version to 0.0.0.dev9 (#188) --- pyproject.toml | 2 +- setup.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 54a6c8d8..99c65b3a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = 'setuptools.build_meta' [project] name = "ratapi" -version = "0.0.0.dev8" +version = "0.0.0.dev9" description = "Python extension for the Reflectivity Analysis Toolbox (RAT)" readme = "README.md" requires-python = ">=3.10" diff --git a/setup.py b/setup.py index 3c21a01f..3bdfea9a 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,6 @@ from setuptools.command.build_clib import build_clib from setuptools.command.build_ext import build_ext -__version__ = "0.0.0.dev8" PACKAGE_NAME = "ratapi" with open("README.md") as f: @@ -152,7 +151,6 @@ def build_libraries(self, libraries): setup( name=PACKAGE_NAME, - version=__version__, author="", author_email="", url="https://github.com/RascalSoftware/python-RAT",