From fccbe7d00f6752974789110ba75856fd8143e828 Mon Sep 17 00:00:00 2001 From: Ingo Bauersachs Date: Sat, 28 Jun 2025 18:15:14 +0200 Subject: [PATCH] Fix section size in header after message normalization Fixes: 2073a0cdea2c560465f7ac0cc56f202e6fc39705 Closes #384 --- src/main/java/org/xbill/DNS/Message.java | 42 ++++++++++++------- src/test/java/org/xbill/DNS/MessageTest.java | 7 +++- src/test/java/org/xbill/DNS/TSIGTest.java | 14 +++++++ .../java/org/xbill/DNS/dnssec/TestBase.java | 14 ++++++- 4 files changed, 59 insertions(+), 18 deletions(-) diff --git a/src/main/java/org/xbill/DNS/Message.java b/src/main/java/org/xbill/DNS/Message.java index bb214c8d..767450b8 100644 --- a/src/main/java/org/xbill/DNS/Message.java +++ b/src/main/java/org/xbill/DNS/Message.java @@ -810,9 +810,11 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord) List additionalSectionSets = getSectionRRsets(Section.ADDITIONAL); List authoritySectionSets = getSectionRRsets(Section.AUTHORITY); - List cleanedAnswerSection = new ArrayList<>(); - List cleanedAuthoritySection = new ArrayList<>(); - List cleanedAdditionalSection = new ArrayList<>(); + @SuppressWarnings("unchecked") + List[] cleanedSection = new ArrayList[4]; + cleanedSection[Section.ANSWER] = new ArrayList<>(); + cleanedSection[Section.AUTHORITY] = new ArrayList<>(); + cleanedSection[Section.ADDITIONAL] = new ArrayList<>(); boolean hadNsInAuthority = false; // For the ANSWER section, remove all "irrelevant" records and add synthesized CNAMEs from @@ -843,7 +845,7 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord) // If DNAME was queried, don't attempt to synthesize CNAME if (query.getQuestion().getType() != Type.DNAME) { // The DNAME is valid, accept it - cleanedAnswerSection.add(rrset); + cleanedSection[Section.ANSWER].add(rrset); // Check if the next rrset is correct CNAME, otherwise synthesize a CNAME RRset nextRRSet = answerSectionSets.size() >= i + 2 ? answerSectionSets.get(i + 1) : null; @@ -863,7 +865,7 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord) // Add a synthesized CNAME; TTL=0 to avoid caching Name dnameTarget = sname.fromDNAME(dname); - cleanedAnswerSection.add( + cleanedSection[Section.ANSWER].add( new RRset(new CNAMERecord(sname, dname.getDClass(), 0, dnameTarget))); sname = dnameTarget; @@ -872,7 +874,7 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord) for (i++; i < answerSectionSets.size(); i++) { rrset = answerSectionSets.get(i); if (rrset.getName().equals(oldSname)) { - cleanedAnswerSection.add(rrset); + cleanedSection[Section.ANSWER].add(rrset); } else { break; } @@ -943,14 +945,14 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord) } sname = ((CNAMERecord) rrset.first()).getTarget(); - cleanedAnswerSection.add(rrset); + cleanedSection[Section.ANSWER].add(rrset); // In CNAME ANY response, can have data after CNAME if (query.getQuestion().getType() == Type.ANY) { for (i++; i < answerSectionSets.size(); i++) { rrset = answerSectionSets.get(i); if (rrset.getName().equals(oldSname)) { - cleanedAnswerSection.add(rrset); + cleanedSection[Section.ANSWER].add(rrset); } else { break; } @@ -973,9 +975,9 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord) } // Mark the additional names from relevant RRset as OK - cleanedAnswerSection.add(rrset); + cleanedSection[Section.ANSWER].add(rrset); if (sname.equals(rrset.getName())) { - addAdditionalRRset(rrset, additionalSectionSets, cleanedAdditionalSection); + addAdditionalRRset(rrset, additionalSectionSets, cleanedSection[Section.ADDITIONAL]); } } @@ -1045,15 +1047,25 @@ public Message normalize(Message query, boolean throwOnIrrelevantRecord) } } - cleanedAuthoritySection.add(rrset); - addAdditionalRRset(rrset, additionalSectionSets, cleanedAdditionalSection); + cleanedSection[Section.AUTHORITY].add(rrset); + addAdditionalRRset(rrset, additionalSectionSets, cleanedSection[Section.ADDITIONAL]); } Message cleanedMessage = new Message(this.getHeader()); cleanedMessage.sections[Section.QUESTION] = this.sections[Section.QUESTION]; - cleanedMessage.sections[Section.ANSWER] = rrsetListToRecords(cleanedAnswerSection); - cleanedMessage.sections[Section.AUTHORITY] = rrsetListToRecords(cleanedAuthoritySection); - cleanedMessage.sections[Section.ADDITIONAL] = rrsetListToRecords(cleanedAdditionalSection); + for (int section : new int[] {Section.ANSWER, Section.AUTHORITY, Section.ADDITIONAL}) { + cleanedMessage.sections[section] = rrsetListToRecords(cleanedSection[section]); + + // Fixup counts in the header + cleanedMessage + .getHeader() + .setCount( + section, + cleanedMessage.sections[section] == null + ? 0 + : cleanedMessage.sections[section].size()); + } + return cleanedMessage; } diff --git a/src/test/java/org/xbill/DNS/MessageTest.java b/src/test/java/org/xbill/DNS/MessageTest.java index 54e32d5c..48639a97 100644 --- a/src/test/java/org/xbill/DNS/MessageTest.java +++ b/src/test/java/org/xbill/DNS/MessageTest.java @@ -35,6 +35,7 @@ // package org.xbill.DNS; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -177,7 +178,9 @@ void normalize() throws WireParseException { response.addRecord(queryRecord, Section.QUESTION); response.addRecord(queryRecord, Section.ADDITIONAL); response = response.normalize(query, true); - assertTrue(response.getSection(Section.ANSWER).isEmpty()); - assertTrue(response.getSection(Section.ADDITIONAL).isEmpty()); + assertThat(response.getSection(Section.ANSWER)).isEmpty(); + assertThat(response.getHeader().getCount(Section.ANSWER)).isZero(); + assertThat(response.getSection(Section.ADDITIONAL)).isEmpty(); + assertThat(response.getHeader().getCount(Section.ADDITIONAL)).isZero(); } } diff --git a/src/test/java/org/xbill/DNS/TSIGTest.java b/src/test/java/org/xbill/DNS/TSIGTest.java index a83fb6e0..416e8747 100644 --- a/src/test/java/org/xbill/DNS/TSIGTest.java +++ b/src/test/java/org/xbill/DNS/TSIGTest.java @@ -2,6 +2,7 @@ package org.xbill.DNS; import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotEquals; @@ -520,6 +521,19 @@ TCPClient createTcpClient(Duration timeout) throws IOException { assertEquals(202, handler.getRecords().size()); } + @Test + void invalidAdditionalCount() { + Message q = Message.newQuery(Record.newRecord(Name.root, Type.A, DClass.IN)); + Message m = new Message(); + m.addRecord(Record.newRecord(Name.root, Type.A, DClass.IN), Section.QUESTION); + m.addRecord(Record.newRecord(Name.root, Type.A, DClass.IN), Section.ANSWER); + m.addRecord( + Record.newRecord(Name.fromConstantString("example.com."), Type.A, DClass.IN), + Section.ADDITIONAL); + assertDoesNotThrow(m::getTSIG); + assertDoesNotThrow(() -> m.normalize(q).getTSIG()); + } + @Getter private static class ZoneBuilderAxfrHandler implements ZoneTransferIn.ZoneTransferHandler { private final List records = new ArrayList<>(); diff --git a/src/test/java/org/xbill/DNS/dnssec/TestBase.java b/src/test/java/org/xbill/DNS/dnssec/TestBase.java index 60f53b2c..a5745c78 100644 --- a/src/test/java/org/xbill/DNS/dnssec/TestBase.java +++ b/src/test/java/org/xbill/DNS/dnssec/TestBase.java @@ -1,6 +1,7 @@ // SPDX-License-Identifier: BSD-3-Clause package org.xbill.DNS.dnssec; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.Mockito.mock; @@ -126,7 +127,18 @@ private void starting(TestInfo description) { Message m; while ((m = messageReader.readMessage(r)) != null) { + for (int i = 0; i < 4; i++) { + assertThat(m.getHeader().getCount(i)) + .withFailMessage("Before normalization") + .isEqualTo(m.getSection(i).size()); + } + m = m.normalize(Message.newQuery(m.getQuestion()), true); + for (int i = 0; i < 4; i++) { + assertThat(m.getHeader().getCount(i)) + .withFailMessage("After normalization") + .isEqualTo(m.getSection(i).size()); + } queryResponsePairs.put(key(m), m); } @@ -286,7 +298,7 @@ protected String getEdeText(Message m) { .flatMap( opt -> opt.getOptions(Code.EDNS_EXTENDED_ERROR).stream() - .filter(o -> o instanceof ExtendedErrorCodeOption) + .filter(ExtendedErrorCodeOption.class::isInstance) .findFirst() .map(o -> ((ExtendedErrorCodeOption) o).getText())) .orElse(null);