From 6a3943374904b39769d8c6fcb3923b6456814d19 Mon Sep 17 00:00:00 2001 From: Otto Moerbeek Date: Mon, 17 Jun 2024 14:58:01 +0200 Subject: [PATCH 1/2] rec: fix TCP case for cached policy tags (cherry picked from commit a7f8db9e9259dfe08e47959a6613f80b971ea535) --- pdns/recursordist/rec-main.hh | 2 +- pdns/recursordist/rec-tcp.cc | 11 +- .../test_Protobuf.py | 142 +++++++++++++++++- 3 files changed, 151 insertions(+), 4 deletions(-) diff --git a/pdns/recursordist/rec-main.hh b/pdns/recursordist/rec-main.hh index 9fc1d96bc8f0..6c7910df5929 100644 --- a/pdns/recursordist/rec-main.hh +++ b/pdns/recursordist/rec-main.hh @@ -125,7 +125,7 @@ struct DNSComboWriter }; std::string d_query; std::unordered_set d_policyTags; - const std::unordered_set d_gettagPolicyTags; + std::unordered_set d_gettagPolicyTags; std::string d_routingTag; std::vector d_records; diff --git a/pdns/recursordist/rec-tcp.cc b/pdns/recursordist/rec-tcp.cc index 7f7c4bef5b94..e74c6f9e2185 100644 --- a/pdns/recursordist/rec-tcp.cc +++ b/pdns/recursordist/rec-tcp.cc @@ -327,16 +327,23 @@ static void doProcessTCPQuestion(std::unique_ptr& comboWriter, s if (t_pdl) { try { if (t_pdl->d_gettag_ffi) { - RecursorLua4::FFIParams params(qname, qtype, comboWriter->d_destination, comboWriter->d_source, comboWriter->d_ednssubnet.source, comboWriter->d_data, comboWriter->d_policyTags, comboWriter->d_records, ednsOptions, comboWriter->d_proxyProtocolValues, requestorId, deviceId, deviceName, comboWriter->d_routingTag, comboWriter->d_rcode, comboWriter->d_ttlCap, comboWriter->d_variable, true, logQuery, comboWriter->d_logResponse, comboWriter->d_followCNAMERecords, comboWriter->d_extendedErrorCode, comboWriter->d_extendedErrorExtra, comboWriter->d_responsePaddingDisabled, comboWriter->d_meta); + RecursorLua4::FFIParams params(qname, qtype, comboWriter->d_destination, comboWriter->d_source, comboWriter->d_ednssubnet.source, comboWriter->d_data, comboWriter->d_gettagPolicyTags, comboWriter->d_records, ednsOptions, comboWriter->d_proxyProtocolValues, requestorId, deviceId, deviceName, comboWriter->d_routingTag, comboWriter->d_rcode, comboWriter->d_ttlCap, comboWriter->d_variable, true, logQuery, comboWriter->d_logResponse, comboWriter->d_followCNAMERecords, comboWriter->d_extendedErrorCode, comboWriter->d_extendedErrorExtra, comboWriter->d_responsePaddingDisabled, comboWriter->d_meta); comboWriter->d_eventTrace.add(RecEventTrace::LuaGetTagFFI); comboWriter->d_tag = t_pdl->gettag_ffi(params); comboWriter->d_eventTrace.add(RecEventTrace::LuaGetTagFFI, comboWriter->d_tag, false); } else if (t_pdl->d_gettag) { comboWriter->d_eventTrace.add(RecEventTrace::LuaGetTag); - comboWriter->d_tag = t_pdl->gettag(comboWriter->d_source, comboWriter->d_ednssubnet.source, comboWriter->d_destination, qname, qtype, &comboWriter->d_policyTags, comboWriter->d_data, ednsOptions, true, requestorId, deviceId, deviceName, comboWriter->d_routingTag, comboWriter->d_proxyProtocolValues); + comboWriter->d_tag = t_pdl->gettag(comboWriter->d_source, comboWriter->d_ednssubnet.source, comboWriter->d_destination, qname, qtype, &comboWriter->d_gettagPolicyTags, comboWriter->d_data, ednsOptions, true, requestorId, deviceId, deviceName, comboWriter->d_routingTag, comboWriter->d_proxyProtocolValues); comboWriter->d_eventTrace.add(RecEventTrace::LuaGetTag, comboWriter->d_tag, false); } + // Copy d_gettagPolicyTags to d_policyTags, so other Lua hooks see them and can add their + // own. Before storing into the packetcache, the tags in d_gettagPolicyTags will be + // cleared by addPolicyTagsToPBMessageIfNeeded() so they do *not* end up in the PC. When an + // Protobuf message is constructed, one part comes from the PC (including the tags + // set by non-gettag hooks), and the tags in d_gettagPolicyTags will be added by the code + // constructing the PB message. + comboWriter->d_policyTags = comboWriter->d_gettagPolicyTags; } catch (const std::exception& e) { if (g_logCommonErrors) { diff --git a/regression-tests.recursor-dnssec/test_Protobuf.py b/regression-tests.recursor-dnssec/test_Protobuf.py index 953a9ce20ead..2af5114d9f23 100644 --- a/regression-tests.recursor-dnssec/test_Protobuf.py +++ b/regression-tests.recursor-dnssec/test_Protobuf.py @@ -302,6 +302,7 @@ def generateRecursorConfig(cls, confdir): @ 3600 IN SOA {soa} a 3600 IN A 192.0.2.42 tagged 3600 IN A 192.0.2.84 +taggedtcp 3600 IN A 192.0.2.87 meta 3600 IN A 192.0.2.85 query-selected 3600 IN A 192.0.2.84 answer-selected 3600 IN A 192.0.2.84 @@ -994,7 +995,7 @@ class ProtobufTagCacheTest(TestRecursorProtobuf): """ % (protobufServersParameters[0].port, protobufServersParameters[1].port) _lua_dns_script_file = """ function gettag(remote, ednssubnet, localip, qname, qtype, ednsoptions, tcp) - if qname:equal('tagged.example.') then + if qname:equal('tagged.example.') or qname:equal('taggedtcp.example.') then return 0, { '' .. math.random() } end return 0 @@ -1036,6 +1037,145 @@ def testTagged(self): ts2 = msg.response.tags[0] self.assertNotEqual(ts1, ts2) + def testTaggedTCP(self): + name = 'taggedtcp.example.' + expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.87') + query = dns.message.make_query(name, 'A', want_dnssec=True) + query.flags |= dns.flags.CD + res = self.sendTCPQuery(query) + self.assertRRsetInAnswer(res, expected) + + msg = self.getFirstProtobufMessage() + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, res) + self.assertEqual(len(msg.response.rrs), 1) + rr = msg.response.rrs[0] + # we have max-cache-ttl set to 15 + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15) + self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.87') + self.checkNoRemainingMessage() + print(msg.response) + self.assertEqual(len(msg.response.tags), 1) + ts1 = msg.response.tags[0] + + # Again to check PC case + res = self.sendTCPQuery(query) + self.assertRRsetInAnswer(res, expected) + + msg = self.getFirstProtobufMessage() + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, res) + print(msg.response) + self.assertEqual(len(msg.response.rrs), 1) + rr = msg.response.rrs[0] + # time may have passed, so do not check TTL + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15, checkTTL=False) + self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.87') + self.checkNoRemainingMessage() + self.assertEqual(len(msg.response.tags), 1) + ts2 = msg.response.tags[0] + self.assertNotEqual(ts1, ts2) + +class ProtobufTagCacheFFITest(TestRecursorProtobuf): + """ + This test makes sure that we correctly cache tags (actually not cache them) for the FFI case + """ + + _confdir = 'ProtobufTagCacheFFI' + _config_template = """ +auth-zones=example=configs/%s/example.zone""" % _confdir + _lua_config_file = """ + protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=false, logResponses=true } ) + """ % (protobufServersParameters[0].port, protobufServersParameters[1].port) + _lua_dns_script_file = """ + local ffi = require("ffi") + + ffi.cdef[[ + typedef struct pdns_ffi_param pdns_ffi_param_t; + + const char* pdns_ffi_param_get_qname(pdns_ffi_param_t* ref); + void pdns_ffi_param_add_policytag(pdns_ffi_param_t* ref, const char* name); + ]] + + function gettag_ffi(obj) + qname = ffi.string(ffi.C.pdns_ffi_param_get_qname(obj)) + if qname == 'tagged.example' or qname == 'taggedtcp.example' then + ffi.C.pdns_ffi_param_add_policytag(obj, '' .. math.random()) + end + return 0 + end + """ + + def testTagged(self): + name = 'tagged.example.' + expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.84') + query = dns.message.make_query(name, 'A', want_dnssec=True) + query.flags |= dns.flags.CD + res = self.sendUDPQuery(query) + self.assertRRsetInAnswer(res, expected) + + msg = self.getFirstProtobufMessage() + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res) + self.assertEqual(len(msg.response.rrs), 1) + rr = msg.response.rrs[0] + # we have max-cache-ttl set to 15 + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15) + self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84') + self.checkNoRemainingMessage() + self.assertEqual(len(msg.response.tags), 1) + ts1 = msg.response.tags[0] + + # Again to check PC case + res = self.sendUDPQuery(query) + self.assertRRsetInAnswer(res, expected) + + msg = self.getFirstProtobufMessage() + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res) + self.assertEqual(len(msg.response.rrs), 1) + rr = msg.response.rrs[0] + # time may have passed, so do not check TTL + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15, checkTTL=False) + self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84') + self.checkNoRemainingMessage() + self.assertEqual(len(msg.response.tags), 1) + ts2 = msg.response.tags[0] + self.assertNotEqual(ts1, ts2) + + def testTaggedTCP(self): + name = 'taggedtcp.example.' + expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.87') + query = dns.message.make_query(name, 'A', want_dnssec=True) + query.flags |= dns.flags.CD + res = self.sendTCPQuery(query) + self.assertRRsetInAnswer(res, expected) + + msg = self.getFirstProtobufMessage() + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, res) + self.assertEqual(len(msg.response.rrs), 1) + rr = msg.response.rrs[0] + # we have max-cache-ttl set to 15 + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15) + self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.87') + self.checkNoRemainingMessage() + print(msg.response) + self.assertEqual(len(msg.response.tags), 1) + ts1 = msg.response.tags[0] + + # Again to check PC case + res = self.sendTCPQuery(query) + self.assertRRsetInAnswer(res, expected) + + msg = self.getFirstProtobufMessage() + self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, res) + print(msg.response) + self.assertEqual(len(msg.response.rrs), 1) + rr = msg.response.rrs[0] + # time may have passed, so do not check TTL + self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15, checkTTL=False) + self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.87') + self.checkNoRemainingMessage() + self.assertEqual(len(msg.response.tags), 1) + ts2 = msg.response.tags[0] + self.assertNotEqual(ts1, ts2) + class ProtobufSelectedFromLuaTest(TestRecursorProtobuf): """ This test makes sure that we correctly export queries and responses but only if they have been selected from Lua. From 9bddb82fd4166c83b055f037f775f33471403704 Mon Sep 17 00:00:00 2001 From: Otto Moerbeek Date: Tue, 18 Jun 2024 10:35:08 +0200 Subject: [PATCH 2/2] Refactor test to avoid code duplciation, as suggested by @rgacogne (cherry picked from commit 3aebfacee518cf32c07efb53e70317a4b2a4019a) --- .../test_Protobuf.py | 117 ++++-------------- 1 file changed, 25 insertions(+), 92 deletions(-) diff --git a/regression-tests.recursor-dnssec/test_Protobuf.py b/regression-tests.recursor-dnssec/test_Protobuf.py index 2af5114d9f23..8e7a72900d06 100644 --- a/regression-tests.recursor-dnssec/test_Protobuf.py +++ b/regression-tests.recursor-dnssec/test_Protobuf.py @@ -982,25 +982,8 @@ def testTagged(self): self.checkProtobufTags(msg, tags) self.checkNoRemainingMessage() -class ProtobufTagCacheTest(TestRecursorProtobuf): - """ - This test makes sure that we correctly cache tags (actually not cache them) - """ - - _confdir = 'ProtobufTagCache' - _config_template = """ -auth-zones=example=configs/%s/example.zone""" % _confdir - _lua_config_file = """ - protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=false, logResponses=true } ) - """ % (protobufServersParameters[0].port, protobufServersParameters[1].port) - _lua_dns_script_file = """ - function gettag(remote, ednssubnet, localip, qname, qtype, ednsoptions, tcp) - if qname:equal('tagged.example.') or qname:equal('taggedtcp.example.') then - return 0, { '' .. math.random() } - end - return 0 - end - """ +class ProtobufTagCacheBase(TestRecursorProtobuf): + __test__ = False def testTagged(self): name = 'tagged.example.' @@ -1074,11 +1057,33 @@ def testTaggedTCP(self): ts2 = msg.response.tags[0] self.assertNotEqual(ts1, ts2) -class ProtobufTagCacheFFITest(TestRecursorProtobuf): +class ProtobufTagCacheTest(ProtobufTagCacheBase): + """ + This test makes sure that we correctly cache tags (actually not cache them) + """ + + __test__ = True + _confdir = 'ProtobufTagCache' + _config_template = """ +auth-zones=example=configs/%s/example.zone""" % _confdir + _lua_config_file = """ + protobufServer({"127.0.0.1:%d", "127.0.0.1:%d"}, { logQueries=false, logResponses=true } ) + """ % (protobufServersParameters[0].port, protobufServersParameters[1].port) + _lua_dns_script_file = """ + function gettag(remote, ednssubnet, localip, qname, qtype, ednsoptions, tcp) + if qname:equal('tagged.example.') or qname:equal('taggedtcp.example.') then + return 0, { '' .. math.random() } + end + return 0 + end + """ + +class ProtobufTagCacheFFITest(ProtobufTagCacheBase): """ This test makes sure that we correctly cache tags (actually not cache them) for the FFI case """ + __test__ = True _confdir = 'ProtobufTagCacheFFI' _config_template = """ auth-zones=example=configs/%s/example.zone""" % _confdir @@ -1104,78 +1109,6 @@ class ProtobufTagCacheFFITest(TestRecursorProtobuf): end """ - def testTagged(self): - name = 'tagged.example.' - expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.84') - query = dns.message.make_query(name, 'A', want_dnssec=True) - query.flags |= dns.flags.CD - res = self.sendUDPQuery(query) - self.assertRRsetInAnswer(res, expected) - - msg = self.getFirstProtobufMessage() - self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res) - self.assertEqual(len(msg.response.rrs), 1) - rr = msg.response.rrs[0] - # we have max-cache-ttl set to 15 - self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15) - self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84') - self.checkNoRemainingMessage() - self.assertEqual(len(msg.response.tags), 1) - ts1 = msg.response.tags[0] - - # Again to check PC case - res = self.sendUDPQuery(query) - self.assertRRsetInAnswer(res, expected) - - msg = self.getFirstProtobufMessage() - self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.UDP, res) - self.assertEqual(len(msg.response.rrs), 1) - rr = msg.response.rrs[0] - # time may have passed, so do not check TTL - self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15, checkTTL=False) - self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.84') - self.checkNoRemainingMessage() - self.assertEqual(len(msg.response.tags), 1) - ts2 = msg.response.tags[0] - self.assertNotEqual(ts1, ts2) - - def testTaggedTCP(self): - name = 'taggedtcp.example.' - expected = dns.rrset.from_text(name, 0, dns.rdataclass.IN, 'A', '192.0.2.87') - query = dns.message.make_query(name, 'A', want_dnssec=True) - query.flags |= dns.flags.CD - res = self.sendTCPQuery(query) - self.assertRRsetInAnswer(res, expected) - - msg = self.getFirstProtobufMessage() - self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, res) - self.assertEqual(len(msg.response.rrs), 1) - rr = msg.response.rrs[0] - # we have max-cache-ttl set to 15 - self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15) - self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.87') - self.checkNoRemainingMessage() - print(msg.response) - self.assertEqual(len(msg.response.tags), 1) - ts1 = msg.response.tags[0] - - # Again to check PC case - res = self.sendTCPQuery(query) - self.assertRRsetInAnswer(res, expected) - - msg = self.getFirstProtobufMessage() - self.checkProtobufResponse(msg, dnsmessage_pb2.PBDNSMessage.TCP, res) - print(msg.response) - self.assertEqual(len(msg.response.rrs), 1) - rr = msg.response.rrs[0] - # time may have passed, so do not check TTL - self.checkProtobufResponseRecord(rr, dns.rdataclass.IN, dns.rdatatype.A, name, 15, checkTTL=False) - self.assertEqual(socket.inet_ntop(socket.AF_INET, rr.rdata), '192.0.2.87') - self.checkNoRemainingMessage() - self.assertEqual(len(msg.response.tags), 1) - ts2 = msg.response.tags[0] - self.assertNotEqual(ts1, ts2) - class ProtobufSelectedFromLuaTest(TestRecursorProtobuf): """ This test makes sure that we correctly export queries and responses but only if they have been selected from Lua.