diff --git a/src/main/java/org/xbill/DNS/Resolver.java b/src/main/java/org/xbill/DNS/Resolver.java index 5d3c0ff3..5b6ce511 100644 --- a/src/main/java/org/xbill/DNS/Resolver.java +++ b/src/main/java/org/xbill/DNS/Resolver.java @@ -9,6 +9,7 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; @@ -239,11 +240,16 @@ default Object sendAsync(Message query, ResolverListener listener) { (result, throwable) -> { if (throwable != null) { Exception exception; + if (throwable instanceof CompletionException && throwable.getCause() != null) { + throwable = throwable.getCause(); + } + if (throwable instanceof Exception) { exception = (Exception) throwable; } else { exception = new Exception(throwable); } + listener.handleException(id, exception); return null; } diff --git a/src/test/java/org/xbill/DNS/ResolverTest.java b/src/test/java/org/xbill/DNS/ResolverTest.java new file mode 100644 index 00000000..365254a8 --- /dev/null +++ b/src/test/java/org/xbill/DNS/ResolverTest.java @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: BSD-3-Clause +package org.xbill.DNS; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.fail; + +import java.net.UnknownHostException; +import java.time.Duration; +import java.util.concurrent.CompletionException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.Test; + +class ResolverTest { + @Test + @SuppressWarnings("deprecation") + void resolverListenerExceptionUnwrap() throws InterruptedException, UnknownHostException { + // 1. Point to a blackhole address from RFC 5737 TEST-NET-1 to ensure a timeout + SimpleResolver resolver = new SimpleResolver("192.0.2.1"); + resolver.setTimeout(Duration.ofSeconds(2)); + + Message query = + Message.newQuery( + Record.newRecord(Name.fromConstantString("example.com."), Type.A, DClass.IN)); + CountDownLatch latch = new CountDownLatch(1); + + // 2. Use the async method with a listener + resolver.sendAsync( + query, + new ResolverListener() { + @Override + public void receiveMessage(Object id, Message m) { + fail("Received message (should not happen)"); + latch.countDown(); + } + + @Override + public void handleException(Object id, Exception ex) { + // 3. Observe the exception type + assertThat(ex).isNotInstanceOf(CompletionException.class); + latch.countDown(); + } + }); + + latch.await(5, TimeUnit.SECONDS); + } +}