diff --git a/domaintools/base_results.py b/domaintools/base_results.py index 30d0735..ad00eed 100644 --- a/domaintools/base_results.py +++ b/domaintools/base_results.py @@ -25,7 +25,6 @@ RequestUriTooLongException, ) - try: # pragma: no cover from collections.abc import MutableMapping, MutableSequence except ImportError: # pragma: no cover @@ -111,7 +110,9 @@ def _make_request(self): if self.product == "iris-investigate" and "irisql" in self.kwargs: irisql_query = self.kwargs["irisql"] auth_keys = {"api_username", "timestamp", "signature", "api_key"} - query_params = {k: v for k, v in self.kwargs.items() if k != "irisql" and k not in auth_keys} + query_params = { + k: v for k, v in self.kwargs.items() if k != "irisql" and k not in auth_keys + } query_params.update(self.api.extra_request_params) return session.post( url=self.url, @@ -164,13 +165,13 @@ def data(self): self._data = results.json() else: self._data = results.text - self.check_limit_exceeded() return self._data def check_limit_exceeded(self): limit_exceeded, reason = False, "" + if isinstance(self._data, dict) and ( "response" in self._data and "limit_exceeded" in self._data["response"] @@ -178,7 +179,13 @@ def check_limit_exceeded(self): ): limit_exceeded, reason = True, self._data["response"]["message"] elif "response" in self._data and "limit_exceeded" in self._data: - limit_exceeded = True + # check for xml format, and return the actual error message + if self.kwargs.get("format") == "xml" and isinstance(self._data, str): + if re.search(r"1", self._data): + msg = re.search(r"(.*?)", self._data) + limit_exceeded, reason = True, msg.group(1) if msg else "" + else: + limit_exceeded = True if limit_exceeded: raise ServiceException(503, f"Limit Exceeded {reason}") @@ -354,9 +361,7 @@ def html(self): ) def as_list(self): - return "\n".join( - [json.dumps(item, indent=4, separators=(",", ": ")) for item in self._items()] - ) + return "\n".join([json.dumps(item, indent=4, separators=(",", ": ")) for item in self._items()]) def __str__(self): return str( diff --git a/tests/test_api.py b/tests/test_api.py index 6c8c2ef..8a05307 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -602,6 +602,34 @@ def test_limit_exceeded(): response.response() +def test_limit_exceeded_xml(): + xml_response = """ + + 413 + Maximum 10000 returned - you may need to refine your query. + + 1 + 1 + Maximum 10000 returned - you may need to refine your query. + + """ + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.text = xml_response + + with patch("domaintools.base_results.Client") as mock_client: + mock_session = MagicMock() + mock_client.return_value.__enter__.return_value = mock_session + mock_session.post.return_value = mock_response + + with pytest.raises(exceptions.ServiceException) as exc_info: + result = api.iris_investigate(ip="8.8.8.8", format="xml") + result.data() + + assert "Maximum 10000 returned" in str(exc_info.value) + + @vcr.use_cassette def test_newly_observed_domains_feed(): results = feeds_api.nod(after="-60", top=5)