import base64
import json
import logging
import os.path
import re
import subprocess
import sys
from urllib.request import urlopen

# DNS server used as the upstream in the generated dnsmasq "server=" rules
# for every proxied host (see get_dnsmasq_text).
PROXY_DNS_IP = '127.0.0.1'
PROXY_DNS_PORT = '5353'

# Output files: dnsmasq rules (written only in the OpenWrt flow) and the
# smartdns domain set (written in both the OpenWrt and PC flows).
DNSMASQ_RULES_FILE = '/tmp/dnsmasq.d/gfwlist.conf'
SMARTDNS_DOMAIN_SET_FILE = '/etc/smartdns/domain-set/gfwlist.conf'

# https://github.com/gfwlist/gfwlist
# Mirrors of the base64-encoded gfwlist; they are tried in order and the
# first one that downloads and decodes successfully is used.
GFWLIST_URL_LIST = [
    "https://raw.githubusercontent.com/gfwlist/gfwlist/master/gfwlist.txt",
    "https://pagure.io/gfwlist/raw/master/f/gfwlist.txt",
    "https://gitlab.com/gfwlist/gfwlist/raw/master/gfwlist.txt",
    "https://git.tuxfamily.org/gfwlist/gfwlist.git/plain/gfwlist.txt",
    "http://repo.or.cz/gfwlist.git/blob_plain/HEAD:/gfwlist.txt"
]

# Directory that contains this script; custom_proxy_hosts.json is looked up here.
PWD = os.path.dirname(os.path.realpath(__file__))


def get_gfwlist_text(urls: "list[str] | None" = None, timeout: float = 15) -> str:
    """Download gfwlist from the first working mirror and return it as text.

    Each mirror's response is base64-decoded and then UTF-8-decoded. The first
    mirror for which all of that succeeds wins.

    :param urls: mirror URLs to try, in order; defaults to ``GFWLIST_URL_LIST``.
    :param timeout: per-request timeout in seconds.
    :return: the decoded gfwlist text.
    :raises IOError: if no mirror produced a usable list. The last mirror's
        failure is chained as ``__cause__``.
    """
    if urls is None:
        urls = GFWLIST_URL_LIST

    last_error = None
    for url in urls:
        logging.info('request {url}'.format(url=url))
        try:
            with urlopen(url, timeout=timeout) as response:
                return base64.b64decode(response.read()).decode('utf-8')
        except Exception as err:  # network, HTTP, base64 and UTF-8 errors alike
            # Best effort: record why this mirror failed and try the next one.
            # (Unlike a bare `except:`, this lets Ctrl-C / SystemExit through.)
            logging.warning('failed to fetch {url}: {err}'.format(url=url, err=err))
            last_error = err
    raise IOError("can't download gfwlist") from last_error


def is_comment(line: str) -> bool:
    """Return True for gfwlist comment lines ("!") and the "[AutoProxy" header line."""
    comment_prefixes = ('!', '[AutoProxy')
    return line.startswith(comment_prefixes)


def has_ip(line: str) -> bool:
    """Return True if an IPv4- or IPv6-looking address occurs anywhere in *line*.

    The patterns are searched (not anchored), so an address embedded in a URL
    rule such as ``|http://1.2.3.4`` also counts.
    """
    # https://stackoverflow.com/questions/5284147/validating-ipv4-addresses-with-regexp
    # https://stackoverflow.com/questions/53497/regular-expression-that-matches-valid-ipv6-addresses
    address_patterns = (
        re.compile(r'((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}'),
        re.compile(
            r'(([0-9a-fA-F]{1,4}:){7,7}[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,7}:|([0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}|([0-9a-fA-F]{1,4}:){1,5}(:[0-9a-fA-F]{1,4}){1,2}|([0-9a-fA-F]{1,4}:){1,4}(:[0-9a-fA-F]{1,4}){1,3}|([0-9a-fA-F]{1,4}:){1,3}(:[0-9a-fA-F]{1,4}){1,4}|([0-9a-fA-F]{1,4}:){1,2}(:[0-9a-fA-F]{1,4}){1,5}|[0-9a-fA-F]{1,4}:((:[0-9a-fA-F]{1,4}){1,6})|:((:[0-9a-fA-F]{1,4}){1,7}|:)|fe80:(:[0-9a-fA-F]{0,4}){0,4}%[0-9a-zA-Z]{1,}|::(ffff(:0{1,4}){0,1}:){0,1}((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])|([0-9a-fA-F]{1,4}:){1,4}:((25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9])\.){3,3}(25[0-5]|(2[0-4]|1{0,1}[0-9]){0,1}[0-9]))'),
    )
    # IPv4 is tried first; `any` stops at the first pattern that hits.
    return any(pattern.search(line) for pattern in address_patterns)


def is_exception(line: str) -> bool:
    """Return True for "@@" whitelist (exception) rules."""
    exception_marker = '@@'
    return line.startswith(exception_marker)


def is_regular(line: str) -> bool:
    """Return True for "/regex/" rules, i.e. lines whose first character is "/"."""
    return line[:1] == '/'


def gfwlist_line_filter(line: str) -> bool:
    """Decide whether a raw gfwlist line should be passed on to the converter.

    After stripping surrounding whitespace, blank lines, comments, rules that
    contain an IP address, "@@" exceptions and "/regex/" rules are rejected.
    """
    stripped = line.strip()
    if not stripped:
        return False
    # Same checks, in the same short-circuit order, as a single rejection test.
    rejected = (
        is_comment(stripped)
        or has_ip(stripped)
        or is_exception(stripped)
        or is_regular(stripped)
    )
    return not rejected


def gfwlist_line_converter(line: str) -> str:
    """Convert one gfwlist rule into a bare host name.

    Handled rule shapes::

        ||global.bing.com          -> global.bing.com
        ||cdn*.i-scmp.com          -> i-scmp.com
        |http://*.1mobile.tw       -> 1mobile.tw
        |http://*2.bahamut.com.tw  -> bahamut.com.tw
        |http://bbs.cantonese.asia/ -> bbs.cantonese.asia
        .casinobellini.com         -> casinobellini.com
        share.dmhy.org             -> share.dmhy.org
        .ddns.net/                 -> ddns.net

    Rejected (returns ``""``), because a DNS rule cannot express them:

    * rules with a path, query or percent-encoded bytes, e.g.
      ``|http://www.dmm.com/netgame``, ``.amazon.com/Dalai-Lama``,
      ``bbs.sina.com%2F``, ``||example.com/path``;
    * rules with a ``*`` anywhere except a wildcard first label;
    * plain keyword rules without a ``.``, e.g. ``ultrareach``, ``q=triangle``.

    :param line: one raw line of the decoded gfwlist.
    :return: the host name, or ``""`` when the rule cannot be converted.
    """
    raw_line = line

    line = line.strip()
    # A single trailing "/" just ends the host part; it is not a path.
    line = re.sub(r'/$', '', line)

    def invalid_rule_return() -> str:
        logging.debug('invalid rule: ' + raw_line)
        return ""

    def has_non_host_parts(host: str) -> bool:
        # Paths, query strings and percent-encoded bytes can't be expressed
        # as a DNS domain rule.
        return any(m in host for m in ('/', '?', '=')) or bool(re.search(r'%\w\w', host))

    def convert_asterisk(host: str) -> str:
        # Drop a wildcard first label such as "*." or "cdn*." ("^" anchors the
        # pattern, so at most one replacement can happen).
        host = re.sub(r'^[\w\-_]*\*[\w\-_]*\.', '', host)
        # A "*" anywhere else cannot be represented; reject the rule.
        if '*' in host:
            return invalid_rule_return()
        return host

    # "||domain" (domain anchor) and "|http://url" (URL-start anchor): strip
    # only the anchor pipes and the scheme; the remainder must be a bare host,
    # checked the same way as plain keyword rules below.
    if line.startswith('|'):
        host = re.sub(r'^https?://', '', line.strip('|'))
        if has_non_host_parts(host):
            return invalid_rule_return()
        return convert_asterisk(host)

    # Plain keyword rules. A keyword without any "." is not a domain.
    if '.' not in line:
        return invalid_rule_return()

    host = re.sub(r'^https?://', '', line)
    if has_non_host_parts(host):
        return invalid_rule_return()

    host = convert_asterisk(host)

    # ".example.com" is written with a leading dot; keep just the domain.
    if host.startswith('.'):
        host = host[1:]

    return host


def is_valid_hostname(domain: str) -> bool:
    """Return True if *domain* is a syntactically valid DNS host name.

    A valid name meets all of these rules:

    * The whole name is 1-255 characters.
    * Each dot-separated label is 1-63 ASCII letters, digits or hyphens.
    * No label starts or ends with a hyphen.
    * A single trailing dot is allowed.
    """
    # https://stackoverflow.com/questions/1418423/the-hostname-regex
    # "\Z" rather than "$": "$" also matches just before a trailing "\n", which
    # would accept "example.com\n" and inject a newline into generated configs.
    domain_re = re.compile(
        r'^(?=.{1,255}\Z)[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?(?:\.[0-9A-Za-z](?:(?:[0-9A-Za-z]|-){0,61}[0-9A-Za-z])?)*\.?\Z')
    return bool(domain_re.match(domain))


def hosts_deduplicate(hosts: list[str]) -> list[str]:
    """Remove invalid, repeated and redundant host names.

    A host is redundant when one of its parent domains is also present, e.g.
    ``www.example.com`` next to ``example.com``. This presumably relies on the
    generated rules also covering subdomains (dnsmasq ``server=/domain/`` does).

    The previous implementation removed entries from the list it was scanning.
    A host with two listed ancestors (``a.b.c``, ``b.c``, ``c``) was removed
    twice and raised ValueError. This version only tests membership in a set.

    :param hosts: candidate host names; they may contain duplicates.
    :return: the surviving host names, sorted.
    """
    valid_hosts = set()
    for host in set(hosts):
        # The custom JSON list may contain non-strings; treat those as invalid.
        if isinstance(host, str) and is_valid_hostname(host):
            valid_hosts.add(host)
        else:
            logging.warning('{host} is invalid!'.format(host=host))

    kept = []
    # Sorted iteration keeps both the result and the log output deterministic.
    for host in sorted(valid_hosts):
        # Parent candidates are the suffixes after each "." in the name
        # ("a.b.c" -> "b.c", "c"). This is O(labels) per host instead of
        # comparing every host against every other host.
        parent = next(
            (host[i + 1:] for i, ch in enumerate(host)
             if ch == '.' and host[i + 1:] in valid_hosts),
            None,
        )
        if parent is None:
            kept.append(host)
        else:
            logging.debug('found duplicate: {k} {v}'.format(k=host, v=parent))

    return kept


def get_gfwlist_hosts() -> list[str]:
    """Download gfwlist and extract a host name from every usable rule.

    Lines rejected by gfwlist_line_filter are skipped. The remaining lines go
    through gfwlist_line_converter, and rules it could not convert (empty
    string) are dropped. Duplicates are not removed here.
    """
    rule_lines = get_gfwlist_text().splitlines()
    # Lazy generator: each line is filtered and converted in turn, as before.
    converted = (
        gfwlist_line_converter(rule)
        for rule in rule_lines
        if gfwlist_line_filter(rule)
    )
    return [host for host in converted if host != ""]


def get_custom_proxy_hosts(path: "str | None" = None) -> list[str]:
    """Load user-supplied proxy host names from a JSON file.

    :param path: JSON file to read. Defaults to ``custom_proxy_hosts.json``
        in the script directory (``PWD``).
    :return: the decoded JSON value, expected to be a list of host name
        strings, or ``[]`` when the file does not exist.
    """
    cph_path = path if path is not None else os.path.join(PWD, 'custom_proxy_hosts.json')
    try:
        # JSON text is UTF-8 by spec; don't depend on the system locale.
        # Opening directly (EAFP) also avoids the exists()/open() race.
        with open(cph_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        return []


def get_proxy_hosts() -> list[str]:
    """Return the sorted, cleaned-up set of hosts to resolve via the proxy DNS.

    gfwlist hosts are combined with the optional custom list. Invalid and
    redundant entries are removed by hosts_deduplicate, and the count is
    logged.
    """
    combined = list(get_gfwlist_hosts())
    # extend() accepts any iterable, just like the original star-unpacking.
    combined.extend(get_custom_proxy_hosts())

    unique_hosts = hosts_deduplicate(combined)
    logging.info('found {num} proxy host'.format(num=len(unique_hosts)))
    return sorted(unique_hosts)


def get_dnsmasq_text(hosts: "list[str] | None" = None) -> str:
    """Render one dnsmasq ``server=/<host>/<ip>#<port>`` line per proxy host.

    :param hosts: host names to render; fetched with get_proxy_hosts() when
        omitted.
    :return: the rules joined by newlines (no trailing newline).
    """
    if hosts is None:
        hosts = get_proxy_hosts()
    return '\n'.join(
        "server=/{host}/{dns_ip}#{dns_port}".format(
            host=host, dns_ip=PROXY_DNS_IP, dns_port=PROXY_DNS_PORT
        )
        for host in hosts
    )


def write_dnsmasq(hosts: "list[str] | None" = None):
    """Write the dnsmasq rules to DNSMASQ_RULES_FILE, overwriting it.

    :param hosts: precomputed proxy hosts; fetched when omitted.
    """
    dnsmasq_text = get_dnsmasq_text(hosts)
    with open(DNSMASQ_RULES_FILE, 'w', encoding='utf-8') as f:
        f.write(dnsmasq_text)


def reload_dnsmasq():
    """Reload dnsmasq via its init script (exit status is not checked)."""
    subprocess.run(["/etc/init.d/dnsmasq", "reload"])


def get_smartdns_domain_set(hosts: "list[str] | None" = None) -> str:
    """Render the smartdns domain set: one host name per line.

    :param hosts: host names to render; fetched with get_proxy_hosts() when
        omitted.
    :return: the host names joined by newlines (no trailing newline).
    """
    if hosts is None:
        hosts = get_proxy_hosts()
    return '\n'.join(hosts)


def write_smartdns_domain_set(hosts: "list[str] | None" = None):
    """Write the smartdns domain set to SMARTDNS_DOMAIN_SET_FILE, overwriting it.

    :param hosts: precomputed proxy hosts; fetched when omitted.
    """
    domain_set_text = get_smartdns_domain_set(hosts)
    with open(SMARTDNS_DOMAIN_SET_FILE, 'w', encoding='utf-8') as f:
        f.write(domain_set_text)


def reload_openwrt_smartdns():
    """Reload smartdns via its OpenWrt init script (exit status is not checked)."""
    subprocess.run(["/etc/init.d/smartdns", "reload"])


def reload_pc_smartdns():
    """Restart the smartdns systemd service (exit status is not checked)."""
    subprocess.run(["systemctl", "restart", "smartdns.service"])


def run_openwrt():
    """Regenerate and reload both the smartdns domain set and the dnsmasq rules."""
    # Download and process gfwlist once. Previously each writer fetched it
    # separately, doubling the network work and risking two different lists.
    hosts = get_proxy_hosts()

    write_smartdns_domain_set(hosts)
    reload_openwrt_smartdns()

    write_dnsmasq(hosts)
    reload_dnsmasq()


def run_pc():
    """Regenerate the smartdns domain set, then restart the systemd smartdns service."""
    for step in (write_smartdns_domain_set, reload_pc_smartdns):
        step()


if __name__ == '__main__':
    import argparse

    logging.basicConfig(stream=sys.stderr, level=logging.INFO, format="%(levelname)s:%(message)s")

    # Each supported environment maps to its entry point; argparse's
    # `choices` guarantees the lookup below cannot miss.
    runners = {"openwrt": run_openwrt, "pc": run_pc}

    parser = argparse.ArgumentParser()
    parser.add_argument("where", choices=["openwrt", "pc"], help="运行环境:openwrt 或 pc")

    args = parser.parse_args()
    runners[args.where]()