# Source code for checkdmarc.utils

# -*- coding: utf-8 -*-
"""DNS utility functions"""

from __future__ import annotations

import logging
import dns
import dns.resolver
import re
from collections import OrderedDict

import publicsuffixlist
from expiringdict import ExpiringDict

"""Copyright 2019-2023 Sean Whalen

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

   https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License."""

# Module-wide default cache used by query_dns() when no cache is passed:
# holds at most 200,000 entries, each kept for up to 1800 seconds (30 min).
DNS_CACHE = ExpiringDict(max_len=200000, max_age_seconds=1800)

# A single whitespace character: a space or a horizontal tab.
WSP_REGEX = r"[ \t]"
# "https://" followed by one or more URL characters or %XX escapes.
# NOTE(review): "$-_" inside the third alternative is a character RANGE
# (0x24-0x5F), not three literals, so it also admits e.g. ";", "<", ">",
# "[" and "]" -- confirm this permissiveness is intended.
HTTPS_REGEX = (
    r"https://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\(\),]|(?:%[0-9a-fA-F]"
    r"[0-9a-fA-F]))+"
)
# A "mailto:" URI. Group 1 is the scheme, group 2 the e-mail address, and
# the optional group 3 is a "!"-prefixed suffix (presumably the DMARC report
# size limit such as "!10m" -- confirm).
# NOTE(review): "+-/" in both character classes is a RANGE (0x2B-0x2F), so
# the first local-part character may also be "," or "." -- confirm intent.
MAILTO_REGEX_STRING = (
    r"^(mailto):([\w\-!#$%&'*+-/=?^_`{|}~]"
    r"[\w\-.!#$%&'*+-/=?^_`{|}~]*@[\w\-.]+)(!\w+)?"
)

# Compiled, case-insensitive form of MAILTO_REGEX_STRING.
MAILTO_REGEX = re.compile(MAILTO_REGEX_STRING, re.IGNORECASE)


class DNSException(Exception):
    """Raised when a general DNS error occurs

    Args:
        error: The underlying cause -- either an exception raised while
            querying DNS or a human-readable message string.

    When ``error`` is a :exc:`dns.exception.Timeout` carrying a ``timeout``
    keyword argument, that value is rounded to one decimal place before the
    error is stored.
    """

    def __init__(self, error):
        if isinstance(error, dns.exception.Timeout):
            # A Timeout constructed without a ``timeout`` kwarg has no such
            # key; indexing it directly would raise KeyError from inside an
            # exception constructor and mask the original DNS error.
            timeout = error.kwargs.get("timeout")
            if timeout is not None:
                error.kwargs["timeout"] = round(timeout, 1)
        super().__init__(error)
class DNSExceptionNXDOMAIN(DNSException):
    """Raised when the queried name does not exist (NXDOMAIN, RCODE 3)."""
# Lazily-built, module-wide PublicSuffixList shared by get_base_domain().
_PUBLIC_SUFFIX_LIST = None


def _get_public_suffix_list() -> publicsuffixlist.PublicSuffixList:
    """Return the shared ``PublicSuffixList``, creating it on first use."""
    global _PUBLIC_SUFFIX_LIST
    if _PUBLIC_SUFFIX_LIST is None:
        # Constructing the list loads and parses the suffix data, which is
        # far more expensive than a lookup; do it once, not per call.
        _PUBLIC_SUFFIX_LIST = publicsuffixlist.PublicSuffixList()
    return _PUBLIC_SUFFIX_LIST


def get_base_domain(domain: str) -> str:
    """
    Gets the base domain name for the given domain

    .. note::
        Results are based on a list of public domain suffixes at
        https://publicsuffix.org/list/public_suffix_list.dat.

    Args:
        domain (str): A domain or subdomain

    Returns:
        str: The base domain of the given domain; the lower-cased input
        itself when no private suffix can be determined
    """
    domain = domain.lower()
    return _get_public_suffix_list().privatesuffix(domain) or domain
def query_dns(domain: str, record_type: str, nameservers: list[str] = None,
              resolver: dns.resolver.Resolver = None, timeout: float = 2.0,
              cache: ExpiringDict = None) -> list[str]:
    """
    Queries DNS

    Results are cached under the lower-cased domain and upper-cased record
    type only (the nameservers/resolver used are not part of the key), and
    empty results are never served from the cache.

    Args:
        domain (str): The domain or subdomain to query about
        record_type (str): The record type to query for
        nameservers (list): A list of one or more nameservers to use
        resolver (dns.resolver.Resolver): A resolver object to use for DNS
            requests; its ``nameservers``, ``timeout`` and ``lifetime``
            attributes are overwritten by this call
        timeout (float): Sets the DNS timeout in seconds
        cache (ExpiringDict): Cache storage; defaults to ``DNS_CACHE``.
            Objects whose type is not exactly ``ExpiringDict`` disable
            caching

    Returns:
        list: A list of answers
    """
    domain = domain.lower()
    record_type = record_type.upper()
    cache_key = f"{domain}_{record_type}"
    if cache is None:
        cache = DNS_CACHE
    use_cache = type(cache) is ExpiringDict
    if use_cache:
        cached_records = cache.get(cache_key)
        if cached_records:
            return cached_records

    if not resolver:
        resolver = dns.resolver.Resolver()
    timeout = float(timeout)
    if nameservers is not None:
        resolver.nameservers = nameservers
    resolver.timeout = timeout
    resolver.lifetime = timeout
    answer = resolver.resolve(domain, record_type, lifetime=timeout)

    if record_type != "TXT":
        records = [rdata.to_text().replace('"', '').rstrip(".")
                   for rdata in answer]
    else:
        # Each TXT rdata exposes its character-strings in ``.strings``;
        # they are concatenated using an empty value of the same type.
        string_groups = [rdata.strings for rdata in answer]
        joined_values = [group[0][:0].join(group)
                         for group in string_groups if group]

        def _decode_if_possible(value):
            # Values that are not valid UTF-8 are returned as raw bytes.
            try:
                return value.decode()
            except UnicodeDecodeError:
                return value

        records = [_decode_if_possible(value) for value in joined_values]

    if use_cache:
        cache[cache_key] = records
    return records
def get_a_records(domain: str, nameservers: list[str] = None,
                  resolver: dns.resolver.Resolver = None,
                  timeout: float = 2.0) -> list[str]:
    """
    Queries DNS for A and AAAA records

    Args:
        domain (str): A domain name
        nameservers (list): A list of nameservers to query
        resolver (dns.resolver.Resolver): A resolver object to use for DNS
            requests
        timeout (float): number of seconds to wait for an answer from DNS

    Returns:
        list: A sorted list of IPv4 and IPv6 addresses; empty when the
        domain has neither record type

    Raises:
        :exc:`checkdmarc.DNSExceptionNXDOMAIN`: the domain does not exist
        :exc:`checkdmarc.DNSException`: any other lookup failure
    """
    found_addresses = []
    for query_type in ("A", "AAAA"):
        try:
            logging.debug(f"Getting {query_type} records for {domain}")
            found_addresses.extend(
                query_dns(domain, query_type, nameservers=nameservers,
                          resolver=resolver, timeout=timeout))
        except dns.resolver.NXDOMAIN:
            raise DNSExceptionNXDOMAIN(f"The domain {domain} does not exist")
        except dns.resolver.NoAnswer:
            # A domain may publish only one of A / AAAA; a missing type is
            # not an error.
            continue
        except Exception as error:
            raise DNSException(error)
    found_addresses.sort()
    return found_addresses
def get_reverse_dns(ip_address: str, nameservers: list[str] = None,
                    resolver: dns.resolver.Resolver = None,
                    timeout: float = 2.0) -> list[str]:
    """
    Queries for an IP addresses reverse DNS hostname(s)

    Args:
        ip_address (str): An IPv4 or IPv6 address
        nameservers (list): A list of nameservers to query
        resolver (dns.resolver.Resolver): A resolver object to use for DNS
            requests
        timeout (float): number of seconds to wait for an answer from DNS

    Returns:
        list: A list of reverse DNS hostnames; empty when the reverse name
        does not exist or exists without any PTR records

    Raises:
        :exc:`checkdmarc.DNSException`
    """
    try:
        name = str(dns.reversename.from_address(ip_address))
        logging.debug(f"Getting PTR records for {ip_address}")
        hostnames = query_dns(name, "PTR", nameservers=nameservers,
                              resolver=resolver, timeout=timeout)
    except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
        # A missing reverse name and a reverse name without PTR data both
        # mean "no reverse hostname". Previously NoAnswer was surfaced as
        # an error while NXDOMAIN returned an empty list.
        return []
    except Exception as error:
        raise DNSException(error)
    return hostnames
def get_txt_records(domain: str, nameservers: list[str] = None,
                    resolver: dns.resolver.Resolver = None,
                    timeout: float = 2.0) -> list[str]:
    """
    Queries DNS for TXT records

    Args:
        domain (str): A domain name
        nameservers (list): A list of nameservers to query
        resolver (dns.resolver.Resolver): A resolver object to use for DNS
            requests
        timeout (float): number of seconds to wait for an answer from DNS

    Returns:
        list: A list of TXT records

    Raises:
        :exc:`checkdmarc.DNSExceptionNXDOMAIN`: the domain does not exist
        :exc:`checkdmarc.DNSException`: no TXT records, or any other
        lookup failure
    """
    try:
        return query_dns(domain, "TXT", nameservers=nameservers,
                         resolver=resolver, timeout=timeout)
    except dns.resolver.NXDOMAIN:
        raise DNSExceptionNXDOMAIN(f"The domain {domain} does not exist")
    except dns.resolver.NoAnswer:
        raise DNSException(
            f"The domain {domain} does not have any TXT records")
    except Exception as error:
        raise DNSException(error)
def get_nameservers(domain: str, approved_nameservers: list[str] = None,
                    nameservers: list[str] = None,
                    resolver: dns.resolver.Resolver = None,
                    timeout: float = 2.0) -> dict:
    """
    Gets a list of nameservers for a given domain

    Args:
        domain (str): A domain name
        approved_nameservers (list): A list of approved nameserver
            substrings, matched case-insensitively against each hostname
        nameservers (list): A list of nameservers to query
        resolver (dns.resolver.Resolver): A resolver object to use for DNS
            requests
        timeout (float): number of seconds to wait for a record from DNS

    Returns:
        OrderedDict: A dictionary with the following keys:
                     - ``hostnames`` - A list of nameserver hostnames
                       (empty when the domain has no NS records)
                     - ``warnings`` - A list of warnings, one per hostname
                       that matches none of ``approved_nameservers``

    Raises:
        :exc:`checkdmarc.DNSExceptionNXDOMAIN`: the domain does not exist
        :exc:`checkdmarc.DNSException`: any other lookup failure
    """
    logging.debug(f"Getting NS records on {domain}")
    hostnames = []
    try:
        hostnames = query_dns(domain, "NS", nameservers=nameservers,
                              resolver=resolver, timeout=timeout)
    except dns.resolver.NXDOMAIN:
        raise DNSExceptionNXDOMAIN(
            f"The domain {domain} does not exist")
    except dns.resolver.NoAnswer:
        pass
    except Exception as error:
        raise DNSException(error)

    warnings = []
    approved_fragments = None
    if approved_nameservers:
        approved_fragments = [entry.lower() for entry in approved_nameservers]
    if approved_fragments:
        for hostname in hostnames:
            lowered_hostname = hostname.lower()
            if not any(fragment in lowered_hostname
                       for fragment in approved_fragments):
                warnings.append(f"Unapproved nameserver: {hostname}")

    return OrderedDict([("hostnames", hostnames), ("warnings", warnings)])
def get_mx_records(domain: str, nameservers: list[str] = None,
                   resolver: dns.resolver.Resolver = None,
                   timeout: float = 2.0) -> list[OrderedDict]:
    """
    Queries DNS for a list of Mail Exchange hosts

    Args:
        domain (str): A domain name
        nameservers (list): A list of nameservers to query
        resolver (dns.resolver.Resolver): A resolver object to use for DNS
            requests
        timeout (float): number of seconds to wait for an answer from DNS

    Returns:
        list: A list of ``OrderedDicts``; each containing a ``preference``
        integer and a ``hostname``, sorted by preference then hostname.
        Empty when the domain has no MX records or publishes only the
        single "No Service" record that query_dns() renders as ``'0 '``.

    Raises:
        :exc:`checkdmarc.DNSExceptionNXDOMAIN`: the domain does not exist
        :exc:`checkdmarc.DNSException`: any other lookup or parsing failure
    """
    mail_hosts = []
    try:
        logging.debug(f"Checking for MX records on {domain}")
        answers = query_dns(domain, "MX", nameservers=nameservers,
                            resolver=resolver, timeout=timeout)
        if answers == ['0 ']:
            logging.debug("\"No Service\" MX record found")
            return []
        for answer in answers:
            # Each answer looks like "<preference> <hostname>".
            fields = answer.split(" ")
            mail_hosts.append(OrderedDict([
                ("preference", int(fields[0])),
                ("hostname", fields[1].rstrip(".").strip().lower()),
            ]))
        mail_hosts.sort(
            key=lambda host: (host["preference"], host["hostname"]))
    except dns.resolver.NXDOMAIN:
        raise DNSExceptionNXDOMAIN(
            f"The domain {domain} does not exist")
    except dns.resolver.NoAnswer:
        pass
    except Exception as error:
        raise DNSException(error)
    return mail_hosts