dns-probe/dns_probe/__init__.py
2025-12-14 19:26:49 +01:00

354 lines
12 KiB
Python

import argparse
import itertools
import logging
import math
from wsgiref.simple_server import make_server

import dns.dnssec
import dns.exception
import dns.message
import dns.name
import dns.query
import dns.rdatatype
import dns.resolver
from prometheus_client import CollectorRegistry, generate_latest, CONTENT_TYPE_LATEST
from prometheus_client.core import GaugeMetricFamily
from pyramid.config import Configurator
from pyramid.httpexceptions import HTTPBadRequest
from pyramid.response import Response
# Version string of this package.
__version__ = '0.4.2'

# Configure the root logger at INFO level as soon as the module is imported,
# then create the module-level logger used by the collector and the server.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DnsCollector(object):
    """Prometheus custom collector that probes the nameservers of one DNS zone.

    The metric families are created once in ``__init__`` and appended to by
    the probing methods, so an instance is meant to serve a single
    ``collect()`` pass (``probe_view`` builds a new collector per request).
    """

    def __init__(self, zone, nameservers=None, ipv4=True, ipv6=True, query_timeout=2):
        """Set up the probe state and the (still empty) metric families.

        :param zone: name of the zone to probe.
        :param nameservers: addresses of the nameservers to query; when empty
            or ``None`` they are discovered from the zone's NS records.
        :param ipv4: resolve A records when discovering nameservers.
        :param ipv6: resolve AAAA records when discovering nameservers.
        :param query_timeout: timeout passed to each direct nameserver query.
        """
        self.zone = zone
        # ``None`` replaces the former mutable ``[]`` default argument.
        self.nameservers = [] if nameservers is None else nameservers
        self.ipv4 = ipv4
        self.ipv6 = ipv6
        self.query_timeout = query_timeout
        # DS records of the zone, filled by fetch_ds().
        self.ds = []
        # Per-nameserver DNSKEY rrset and its RRSIG rrset, filled by list_dnskey().
        self.keys = {}
        self.keys_rrsig = {}
        self.ns_resolve_sucess_metrics = GaugeMetricFamily(
            'dns_probe_resolve_nameservers_success',
            'Probe sucessfully managed to fetch the list of ns',
            labels=['zone']
        )
        self.soa_serial_metrics = GaugeMetricFamily(
            'dns_probe_soa_serial',
            'Serial of SOA',
            labels=['zone', 'nameserver']
        )
        self.rrsig_expiration_metrics = GaugeMetricFamily(
            'dns_probe_soa_rrsig_expiration',
            'Expiration date of DNSSEC signature',
            labels=['zone', 'nameserver', 'keytag']
        )
        self.ns_set_metrics = GaugeMetricFamily(
            'dns_probe_ns_set',
            'List of nameservers',
            labels=['zone', 'target', 'nameserver']
        )
        self.query_success_metrics = GaugeMetricFamily(
            'dns_probe_query_success',
            'Status of DNS query to nameserver',
            labels=['name', 'type', 'nameserver']
        )
        self.dnssec_delegation_status_metrics = GaugeMetricFamily(
            'dns_probe_dnssec_delegation_status',
            'Status of DNSSEC delegation: secure, unsecure, bogus',
            labels=['zone', 'status']
        )
        self.dnssec_ds_has_key_metrics = GaugeMetricFamily(
            'dns_probe_dnssec_ds_has_key',
            'List of DS records, 1 if has a corresponding DNSKEY in the zone, 0 if not',
            labels=['zone', 'keytag', 'digest']
        )
        self.dnssec_dnskey_signature_status_metrics = GaugeMetricFamily(
            'dns_probe_dnssec_dnskey_signature_status',
            'Status of DNSKEY rrset signature',
            labels=['zone', 'nameserver']
        )

    def fetch_ds(self):
        """Store the zone's DS records in ``self.ds`` (empty when there are none)."""
        self.ds = list(self.resolve(self.zone, dns.rdatatype.DS, raise_on_no_answer=False))

    def fetch_ns(self):
        """Replace ``self.nameservers`` with the addresses of the zone's NS targets."""
        self.nameservers = [
            addr
            for ns in self.resolve(self.zone, dns.rdatatype.NS)
            for addr in self.resolve_addr(ns.target)
        ]

    def resolve_addr(self, qname):
        """Return an iterable of the A and/or AAAA addresses of ``qname``.

        Which families are looked up is controlled by ``self.ipv4`` and
        ``self.ipv6``.
        """
        a_records = aaaa_records = ()
        # A name that only has addresses in one family must not abort the
        # whole lookup, so a missing A/AAAA answer is treated as "no address"
        # instead of raising NoAnswer.
        if self.ipv4:
            a_records = self.resolve(qname, dns.rdatatype.A, raise_on_no_answer=False)
        if self.ipv6:
            aaaa_records = self.resolve(qname, dns.rdatatype.AAAA, raise_on_no_answer=False)
        return (addr.address for addr in itertools.chain(a_records, aaaa_records))

    def resolve(self, qname, qtype, raise_on_no_answer=True):
        """Resolve ``qname``/``qtype`` and return the matching records.

        Only rrsets whose type equals ``qtype`` are kept from the answer
        section. The resolver call itself is made immediately, so resolver
        errors are raised by this call and not while iterating the result.
        """
        answer = dns.resolver.resolve(qname, qtype, raise_on_no_answer=raise_on_no_answer)
        return itertools.chain.from_iterable(
            rrset.items.keys()
            for rrset in answer.response.answer
            if rrset.rdtype == qtype
        )

    def query(self, qname, qtype, ns, dnssec=False):
        """Send one query directly to nameserver ``ns`` and return the response.

        A ``dns_probe_query_success`` sample (1 or 0) is recorded for every
        call. DNS and OS errors are re-raised after recording the failure.
        """
        try:
            res, _is_tcp = dns.query.udp_with_fallback(
                dns.message.make_query(qname, qtype, want_dnssec=dnssec),
                ns,
                timeout=self.query_timeout
            )
        except (dns.exception.DNSException, OSError):
            self.query_success_metrics.add_metric([qname, qtype, ns], 0)
            raise
        else:
            self.query_success_metrics.add_metric([qname, qtype, ns], 1)
        return res

    def check_soa(self, ns):
        """Record the SOA serial and SOA RRSIG expirations reported by ``ns``."""
        res_soa = self.query(self.zone, 'SOA', ns, dnssec=True)
        soa = None
        rrsig_set = []
        for rrset in res_soa.answer:
            if rrset.rdtype == dns.rdatatype.SOA:
                soa = list(rrset.items.keys())[0]
            if rrset.rdtype == dns.rdatatype.RRSIG:
                rrsig_set = list(rrset.items.keys())
        if soa is None:
            # The answer section carried no SOA: skip the serial instead of
            # failing the whole scrape with an AttributeError.
            logger.warning(f'NS {ns} returned no SOA record for zone {self.zone}')
        else:
            self.soa_serial_metrics.add_metric([self.zone, ns], soa.serial)
        for rrsig in rrsig_set:
            self.rrsig_expiration_metrics.add_metric([self.zone, ns, str(rrsig.key_tag)], rrsig.expiration)

    def check_dnssec_delegation(self):
        """Derive the DNSSEC delegation status from the DS and DNSKEY data.

        Status is ``unsecure`` when the zone has no DS record, ``bogus`` when
        no DS matches a DNSKEY or when a nameserver's DNSKEY rrset has no
        valid signature from a key matching a DS, and ``secure`` otherwise.
        A nameserver for which no DNSKEY data was collected is treated as
        having no keys and no valid signature.
        """
        if not self.ds:
            self.dnssec_delegation_status_metrics.add_metric([self.zone, 'unsecure'], 1)
            return
        ds_with_key = set()
        ds_without_key = {}
        trusted_keys = {}
        # Only keys with both the SEP and ZONE flags are compared against DS records.
        key_flags = dns.dnssec.Flag.SEP | dns.dnssec.Flag.ZONE
        for ns in self.nameservers:
            ds_without_key[ns] = set()
            trusted_keys[ns] = set()
            # list_dnskey() may have failed for this nameserver, leaving no entry.
            ns_keys = self.keys.get(ns, ())
            for ds in self.ds:
                has_key = False
                for key in ns_keys:
                    if (key.flags & key_flags) != key_flags:
                        continue
                    key_ds = dns.dnssec.make_ds(
                        self.zone,
                        key,
                        ds.digest_type,
                        policy=dns.dnssec.allow_all_policy
                    )
                    if ds == key_ds:
                        ds_with_key.add(ds)
                        trusted_keys[ns].add(key)
                        has_key = True
                        break
                if not has_key:
                    ds_without_key[ns].add(ds)
        # A DS is reported with value 0 only when no nameserver has a key for it.
        # set.intersection() needs at least one argument, hence the guard for
        # an empty nameserver list.
        if ds_without_key:
            ds_missing_everywhere = set.intersection(*ds_without_key.values())
        else:
            ds_missing_everywhere = set()
        for ds in ds_missing_everywhere:
            self.dnssec_ds_has_key_metrics.add_metric([self.zone, str(ds.key_tag), ds.digest.hex()], 0)
        for ds in ds_with_key:
            self.dnssec_ds_has_key_metrics.add_metric([self.zone, str(ds.key_tag), ds.digest.hex()], 1)
        if not ds_with_key:
            self.dnssec_delegation_status_metrics.add_metric([self.zone, 'bogus'], 1)
            return
        has_valid_signature = {}
        domain = dns.name.from_text(self.zone)
        for ns in self.nameservers:
            ns_trusted_keys = trusted_keys.get(ns, [])
            has_valid_signature[ns] = False
            ns_keys = self.keys.get(ns)
            if ns_keys is not None:
                for rrsig in self.keys_rrsig.get(ns, ()):
                    try:
                        dns.dnssec.validate_rrsig(ns_keys, rrsig, {domain: ns_trusted_keys})
                    except (dns.dnssec.ValidationFailure, dns.dnssec.UnsupportedAlgorithm):
                        # Try the next signature.
                        continue
                    has_valid_signature[ns] = True
                    break
            self.dnssec_dnskey_signature_status_metrics.add_metric([self.zone, ns], int(has_valid_signature[ns]))
        if all(has_valid_signature.values()):
            self.dnssec_delegation_status_metrics.add_metric([self.zone, 'secure'], 1)
        else:
            self.dnssec_delegation_status_metrics.add_metric([self.zone, 'bogus'], 1)

    def list_ns(self, ns):
        """Record one ``dns_probe_ns_set`` sample per NS target reported by ``ns``."""
        res_ns = self.query(self.zone, 'NS', ns)
        ns_set = []
        for rrset in res_ns.answer:
            if rrset.rdtype == dns.rdatatype.NS:
                ns_set = list(rrset.items.keys())
        for ns_record in ns_set:
            target = ns_record.target.to_text()
            self.ns_set_metrics.add_metric([self.zone, target, ns], 1)

    def list_dnskey(self, ns):
        """Store the DNSKEY rrset and its RRSIG rrset returned by ``ns``."""
        res_dnskey = self.query(self.zone, 'DNSKEY', ns, dnssec=True)
        for rrset in res_dnskey.answer:
            if rrset.rdtype == dns.rdatatype.DNSKEY:
                self.keys[ns] = rrset
            if rrset.rdtype == dns.rdatatype.RRSIG:
                self.keys_rrsig[ns] = rrset

    def collect(self):
        """Run every probe and yield the populated metric families.

        DNS and OS errors from individual probes are logged and do not stop
        the remaining probes.
        """
        if not self.nameservers:
            try:
                self.fetch_ns()
            except dns.exception.Timeout:
                logger.error(f'Timeout while querying for NS for zone {self.zone}')
                # A timeout is a failed resolution too: report it instead of
                # leaving the success metric without any sample.
                self.ns_resolve_sucess_metrics.add_metric([self.zone], 0)
            except (dns.exception.DNSException, OSError):
                logger.exception(f'Failed to get fetch nameservers for zone {self.zone}')
                self.ns_resolve_sucess_metrics.add_metric([self.zone], 0)
            else:
                if self.nameservers:
                    self.ns_resolve_sucess_metrics.add_metric([self.zone], 1)
                else:
                    # NS records exist but none resolved to a usable address.
                    logger.warning(f'No nameserver address found for zone {self.zone}')
                    self.ns_resolve_sucess_metrics.add_metric([self.zone], 0)
        try:
            self.fetch_ds()
        except dns.exception.Timeout:
            logger.warning(f'Timeout while querying for DS for zone {self.zone}')
        except (dns.exception.DNSException, OSError):
            logger.exception(f'Failed to fetch DS records for zone {self.zone}')
        for ns in self.nameservers:
            try:
                self.list_ns(ns)
            except dns.exception.Timeout:
                logger.warning(f'NS {ns} timeout while querying for NS for zone {self.zone}')
            except (dns.exception.DNSException, OSError):
                logger.exception(f'Failed to list NS from nameserver {ns} for zone {self.zone}')
            try:
                self.list_dnskey(ns)
            except dns.exception.Timeout:
                logger.warning(f'NS {ns} timeout while querying for DNSKEY for zone {self.zone}')
            except (dns.exception.DNSException, OSError):
                logger.exception(f'Failed to list DNSKEY from nameserver {ns} for zone {self.zone}')
            try:
                self.check_soa(ns)
            except dns.exception.Timeout:
                logger.warning(f'NS {ns} timeout while querying for SOA for zone {self.zone}')
            except (dns.exception.DNSException, OSError):
                logger.exception(f'Failed to get SOA metrics from nameserver {ns} for zone {self.zone}')
        self.check_dnssec_delegation()
        yield self.ns_resolve_sucess_metrics
        yield self.ns_set_metrics
        yield self.soa_serial_metrics
        yield self.rrsig_expiration_metrics
        yield self.query_success_metrics
        yield self.dnssec_ds_has_key_metrics
        yield self.dnssec_delegation_status_metrics
        yield self.dnssec_dnskey_signature_status_metrics
def parse_bool(val, name, default):
    """Convert a textual request parameter into a boolean.

    :param val: raw parameter value, or ``None`` when the parameter is absent.
    :param name: parameter name, used only in the error message.
    :param default: value returned when ``val`` is ``None``.
    :raises HTTPBadRequest: when ``val`` is not one of the accepted spellings.
    """
    if val is None:
        return default
    # Accepted spellings, compared case-insensitively.
    spellings = {
        'true': True, '1': True, 'enabled': True, 'yes': True,
        'false': False, '0': False, 'disabled': False, 'no': False,
    }
    parsed = spellings.get(val.lower())
    if parsed is None:
        raise HTTPBadRequest(f'Unknown value for param {name}, allowed values: true, 1, enabled, yes, false, 0, disabled, no')
    return parsed
def parse_float(val, name, default):
    """Convert a textual request parameter into a finite float.

    :param val: raw parameter value, or ``None`` when the parameter is absent.
    :param name: parameter name, used only in the error messages.
    :param default: value returned unchanged when ``val`` is ``None``.
    :raises HTTPBadRequest: when ``val`` is not a number, or is NaN/infinite.
    """
    if val is None:
        return default
    try:
        number = float(val)
    except ValueError:
        raise HTTPBadRequest(f'Could not convert the value of {name} to float') from None
    # float() happily parses 'nan', 'inf' and '-inf'; none of them is a usable
    # value for the caller (the result is used as a query timeout).
    if not math.isfinite(number):
        raise HTTPBadRequest(f'Value of {name} must be a finite number')
    return number
def probe_view(request):
    """Handle ``GET /probe``: probe one zone and return its metrics.

    Query parameters read: ``zone`` (required), ``nameservers[]`` (repeatable),
    ``ipv4``, ``ipv6`` and ``query_timeout``.

    :raises HTTPBadRequest: when ``zone`` is missing or a parameter is invalid.
    """
    params = request.params
    zone = params.get('zone')
    nameservers = params.getall('nameservers[]')
    # Optional parameters are validated before the presence of ``zone`` is checked.
    ipv4 = parse_bool(params.get('ipv4'), 'ipv4', True)
    ipv6 = parse_bool(params.get('ipv6'), 'ipv6', True)
    query_timeout = parse_float(params.get('query_timeout'), 'query_timeout', 2)

    if zone is None:
        raise HTTPBadRequest('zone parameter is required')

    # Make sure the zone name ends with a dot.
    if not zone.endswith('.'):
        zone = zone + '.'

    # A fresh registry and collector per request.
    registry = CollectorRegistry()
    collector = DnsCollector(
        zone,
        nameservers=nameservers,
        ipv4=ipv4,
        ipv6=ipv6,
        query_timeout=query_timeout,
    )
    registry.register(collector)
    payload = generate_latest(registry)
    return Response(payload, content_type=CONTENT_TYPE_LATEST)
def make_app():
    """Build the Pyramid WSGI application exposing ``probe_view`` at ``/probe``."""
    with Configurator() as configurator:
        configurator.add_route('probe', '/probe')
        configurator.add_view(probe_view, route_name='probe', request_method='GET')
        wsgi_app = configurator.make_wsgi_app()
    return wsgi_app
def serve(ip='127.0.0.1', port=8953):
    """Create a wsgiref server for the probe app on ``ip``:``port`` and run it.

    The call blocks inside ``serve_forever()``.
    """
    application = make_app()
    httpd = make_server(ip, port, application)
    logger.info(f'Starting webserver on {ip}:{port}')
    httpd.serve_forever()
def parse_listen(listen_str):
    """Parse a ``[ip]:port`` (or bare ``port``) listen string.

    The string is split on its last ``:``. An empty or missing IP part
    defaults to ``0.0.0.0``.

    :param listen_str: text such as ``127.0.0.1:8953``, ``:8953`` or ``8953``.
    :returns: ``{'ip': <str>, 'port': <int>}``.
    :raises ValueError: when the port is empty, not an integer, or outside
        the 0-65535 range.
    """
    ip, _sep, port = listen_str.rpartition(':')
    if ip == '':
        ip = '0.0.0.0'
    if port == '':
        raise ValueError('Port can not be empty')
    port_number = int(port)
    # Reject out-of-range ports here so the problem is reported while parsing
    # the argument rather than later when the socket is bound.
    if not 0 <= port_number <= 65535:
        raise ValueError(f'Port must be in range 0-65535, got {port_number}')
    return {'ip': ip, 'port': port_number}
def main():
    """Command-line entry point: parse ``--listen`` and start the web server."""
    cli = argparse.ArgumentParser(
        description='DNS probe that exports Prometheus-like data',
    )
    cli.add_argument(
        '-l',
        '--listen',
        type=parse_listen,
        default='127.0.0.1:8953',
        help='Address to listen to, default %(default)s',
    )
    options = cli.parse_args()
    # ``options.listen`` is the mapping produced by parse_listen().
    serve(**options.listen)


# Start the server only when the module is executed as a script.
if __name__ == '__main__':
    main()