import logging
import subprocess
from http.client import HTTPException
from urllib.error import URLError
from urllib.request import urlopen, Request

from dnscrypt import parse, DNSoverHTTPS

# Output file written by write_smartdns_config().
SMARTDNS_GFW_CONF_FILE = '/etc/smartdns/conf.d/gfw-server.conf'

# https://github.com/DNSCrypt/dnscrypt-resolvers
# Mirrors of the DNSCrypt public resolver list; get_public_resolver_md() tries
# them in order until one download succeeds. Every entry needs its trailing
# comma: without it Python silently concatenates adjacent string literals into
# a single bogus URL.
PUBLIC_RESOLVER_URL_LIST = [
    "https://download.dnscrypt.info/resolvers-list/v3/public-resolvers.md",
    "https://raw.githubusercontent.com/DNSCrypt/dnscrypt-resolvers/master/v3/public-resolvers.md",
    "https://dnsr.evilvibes.com/v3/public-resolvers.md",
]


def get_public_resolver_md(urls=None) -> str:
    """Download the DNSCrypt ``public-resolvers.md`` list as text.

    Each URL is tried in order; the first one that downloads and decodes as
    UTF-8 wins. Failures are logged as warnings and the next URL is tried.

    Args:
        urls: optional iterable of URLs to try. Defaults to
            PUBLIC_RESOLVER_URL_LIST.

    Returns:
        The decoded markdown document.

    Raises:
        IOError: if every URL failed. The last underlying error is chained
            as ``__cause__``.
    """
    if urls is None:
        urls = PUBLIC_RESOLVER_URL_LIST
    last_error = None
    for url in urls:
        logging.info('request {url}'.format(url=url))
        try:
            with urlopen(url, timeout=15) as response:
                return response.read().decode('utf-8')
        except (OSError, HTTPException, ValueError) as e:
            # OSError covers URLError/HTTPError, timeouts and TLS errors;
            # ValueError covers malformed URLs and UnicodeDecodeError.
            # A bare except would also swallow KeyboardInterrupt/SystemExit.
            logging.warning('request %s failed: %s', url, e)
            last_error = e
    raise IOError("can't download public-resolvers.md") from last_error


def get_stamps():
    """Fetch the public resolver list and parse every DNS stamp in it.

    Only lines that begin with ``sdns://`` are treated as stamps; each one is
    handed to ``dnscrypt.parse`` in document order.

    Returns:
        A list of parsed stamp objects, in the order they appear.
    """
    markdown: str = get_public_resolver_md()
    return [
        parse(row)
        for row in markdown.splitlines()
        if row.startswith('sdns://')
    ]


def get_not_china_doh_list():
    """Select the DNS-over-HTTPS stamps that are unfiltered and DNSSEC-capable.

    A stamp is kept when it is a ``DNSoverHTTPS`` instance and neither its
    ``nofilter`` nor its ``dnssec`` attribute is exactly ``False``.
    NOTE(review): despite the function name, no country/location filtering
    happens here; only the stamp type and those two properties are checked.

    Returns:
        A list of the matching stamps, in their original order.
    """
    def wanted(stamp) -> bool:
        # Identity checks against False (not truthiness) are deliberate and
        # preserved: any value other than the literal False passes.
        return (
            isinstance(stamp, DNSoverHTTPS)
            and stamp.nofilter is not False
            and stamp.dnssec is not False
        )

    return [stamp for stamp in get_stamps() if wanted(stamp)]


# def build_dns_query(domain, record_type):
#     # require: dnspython
#     import dns.message
#     dnsq = dns.message.make_query(
#         qname=domain,
#         rdtype=record_type,
#         want_dnssec=False,
#     )
#     return dnsq

# URLs already probed by doh_tester(). A URL found here is not probed again, and
# doh_tester() returns False for it, so each endpoint is accepted at most once
# per process.
TESTED_URL = set()


def doh_tester(hostname, path):
    """Probe a DoH endpoint with a real DNS query and report whether it answered.

    POSTs a wire-format query for ``www.google.com A`` with
    ``application/dns-message`` headers to ``https://<hostname><path>`` and
    prints ``<url>\\tok`` or ``<url>\\tfailed\\t<reason>``.

    Each URL is probed at most once per process. A URL already recorded in
    TESTED_URL is not probed again and yields False, which also drops
    duplicate endpoints.

    Args:
        hostname: host part of the endpoint.
        path: URL path of the endpoint, appended directly to the hostname.

    Returns:
        True if the endpoint answered with something that looks like a DNS
        reply to this query; False on any network/HTTP error or a non-DNS body.
    """
    # www.google.com A -- header: ID 0xf904, flags 0x0100 (RD), QDCOUNT 1
    dnsq = b'\xf9\x04\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03www\x06google\x03com\x00\x00\x01\x00\x01'
    url = 'https://' + hostname + path
    if url in TESTED_URL:
        return False
    TESTED_URL.add(url)
    request = Request(url, data=dnsq)
    request.add_header('accept', 'application/dns-message')
    request.add_header('content-type', 'application/dns-message')
    try:
        with urlopen(request, timeout=3) as response:
            body = response.read()
    except (OSError, HTTPException, ValueError) as e:
        # OSError covers URLError/HTTPError, socket.timeout (which is not a
        # TimeoutError before Python 3.10), connection resets and ssl.SSLError
        # raised during read(). ValueError covers UnicodeError from odd
        # hostnames. One bad server must not abort the whole run.
        print('{url}\tfailed\t{error}'.format(url=url, error=e))
        return False
    # An HTTP 200 alone is not proof of a DoH server (captive portals, HTML
    # pages). Require a full 12-byte DNS header that echoes the query ID and
    # has the QR (response) bit set.
    if len(body) < 12 or body[:2] != dnsq[:2] or not (body[2] & 0x80):
        print('{url}\tfailed\t{error}'.format(url=url, error='not a DNS response'))
        return False
    print('{url}\tok'.format(url=url))
    return True


def get_final_doh_list():
    """Keep only the candidate DoH stamps whose endpoint passes doh_tester().

    Candidates come from get_not_china_doh_list() and are probed one by one,
    in order.

    Returns:
        A list of the stamps that answered the live probe.
    """
    reachable = []
    for candidate in get_not_china_doh_list():
        if doh_tester(candidate.hostname, candidate.path):
            reachable.append(candidate)
    return reachable


def get_smartdns_config():
    """Build SmartDNS config text for all DoH servers that passed the probe.

    Each working stamp becomes one ``server-https`` line in the ``GFW``
    server group. Duplicate lines are removed and the rest are sorted. A bare
    set would give a different order on every run, because str hashing is
    randomized per process, and cause spurious config-file changes.

    Returns:
        Newline-separated config lines with no trailing newline, or an empty
        string when no server passed.
    """
    stamps = get_final_doh_list()
    lines = sorted({
        'server-https https://' + s.hostname + s.path + ' -group GFW'
        for s in stamps
    })
    return '\n'.join(lines)


def write_smartdns_config():
    """Regenerate SMARTDNS_GFW_CONF_FILE from the currently working DoH servers.

    If no server passed, the existing file is left untouched and a warning is
    logged. Overwriting it with an empty file would leave SmartDNS's GFW group
    without any upstream after the following reload.
    """
    conf_txt = get_smartdns_config()
    if not conf_txt:
        logging.warning('no usable DoH server found; keeping %s unchanged',
                        SMARTDNS_GFW_CONF_FILE)
        return
    with open(SMARTDNS_GFW_CONF_FILE, 'w', encoding='utf-8') as f:
        # End with a newline so the file is a well-formed POSIX text file.
        f.write(conf_txt + '\n')


def reload_openwrt_smartdns():
    """Reload SmartDNS through its OpenWrt init script.

    Raises:
        subprocess.CalledProcessError: if the init script exits non-zero.
            Previously the exit status was ignored and the script reported
            success.
    """
    subprocess.run(["/etc/init.d/smartdns", "reload"], check=True)


def reload_pc_smartdns():
    """Restart the ``smartdns.service`` systemd unit.

    Raises:
        subprocess.CalledProcessError: if ``systemctl`` exits non-zero.
            Previously the exit status was ignored and the script reported
            success.
    """
    subprocess.run(["systemctl", "restart", "smartdns.service"], check=True)


def run_openwrt():
    """Rewrite the GFW server config, then reload SmartDNS the OpenWrt way."""
    for step in (write_smartdns_config, reload_openwrt_smartdns):
        step()


def run_pc():
    """Rewrite the GFW server config, then restart SmartDNS via systemd."""
    for step in (write_smartdns_config, reload_pc_smartdns):
        step()


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # Help text: "runtime environment: openwrt or pc".
    parser.add_argument("where", choices=["openwrt", "pc"], help="运行环境:openwrt 或 pc")

    args = parser.parse_args()

    # Without a configured handler the root logger stays at WARNING and every
    # logging.info() call in this script is silently discarded.
    logging.basicConfig(level=logging.INFO)

    if args.where == "openwrt":
        run_openwrt()
    elif args.where == "pc":
        run_pc()