import base64
import json

from .parser import props, LP_decode, LP_pk, VLP_hashs, VLP_bootstrap_ipi


# References
# https://dnscrypt.info/stamps-specifications
# https://dnscrypt.info/protocol
# https://github.com/DNSCrypt/dnscrypt-resolvers/blob/21fcbaf858112c63fed2a504714cc829bd654483/utils/format.py#L101-L141

def parse(stamp: str):
    """Decode any ``sdns://`` stamp into the matching stamp object.

    The first byte of the decoded payload is the protocol identifier; it
    selects which stamp class parses the full stamp.

    :param stamp: stamp string, with or without the ``sdns://`` prefix.
    :returns: an instance of the stamp class registered for the identifier.
    :raises ValueError: if the payload is empty or carries an unknown
        protocol identifier (``binascii.Error``, a ``ValueError`` subclass,
        is raised for malformed base64).
    """
    # Stamps are unpadded URL-safe base64; "==" supplies enough padding for
    # any length and the lenient decoder tolerates the surplus.
    b = base64.urlsafe_b64decode(stamp.removeprefix("sdns://") + "==")
    if not b:
        raise ValueError("stamp: empty payload")

    # Built at call time because the stamp classes are defined below.
    parsers = {
        0x00: PlainDNS,
        0x01: DNSCrypt,
        0x02: DNSoverHTTPS,
        0x03: DNSoverTLS,
        0x04: DNSoverQUIC,
        0x05: ObliviousDoH,
        0x81: DNSCryptRelay,
        0x85: ObliviousDoHRelay,
    }
    cls = parsers.get(b[0])
    if cls is None:
        # Previously fell through and returned None silently.
        raise ValueError(f"stamp: unknown protocol identifier 0x{b[0]:02x}")
    return cls.parse(stamp)


class Base:
    """Shared serialisation helpers for every stamp type.

    Serialisation works from instance attributes: whatever has been assigned
    on the object (e.g. by a ``parse`` method) is what gets emitted.
    """

    def to_json(self):
        """Return this object's instance attributes as a JSON string.

        Any non-JSON-native value (including nested objects) is serialised
        through its ``__dict__``. Class-level defaults that were never
        assigned on the instance are not part of ``__dict__`` and are omitted.
        """
        def _instance_attrs(obj):
            return obj.__dict__

        return json.dumps(self, default=_instance_attrs)

    def __str__(self):
        """Return ``ClassName<json>`` for quick inspection."""
        return f"{type(self).__name__}<{self.to_json()}>"


class DNSCrypt(Base):
    """DNSCrypt server stamp (protocol identifier 0x01).

    Layout: ``0x01 || props || LP(addr[:port]) || LP(pk) || LP(providerName)``.
    """

    dnssec: bool = False
    nolog: bool = False
    nofilter: bool = False
    addr: str = None
    pk: str = None
    provider: str = None

    @staticmethod
    def parse(stamp: str):
        """Decode a DNSCrypt ``sdns://`` stamp.

        :param stamp: stamp string, with or without the ``sdns://`` prefix.
        :returns: a populated :class:`DNSCrypt`.
        :raises ValueError: if the payload is empty or does not start with
            0x01 (``binascii.Error``, a ``ValueError`` subclass, for bad base64).
        """
        # Stamps are unpadded URL-safe base64; "==" supplies enough padding for
        # any length and the lenient decoder tolerates the surplus.
        b = base64.urlsafe_b64decode(stamp.removeprefix("sdns://") + "==")

        # 0x01 -- compared as a slice so an empty payload raises ValueError
        # rather than IndexError.
        i = 0
        if b[i:i + 1] != b"\x01":
            raise ValueError("DNSCrypt: 0x01")
        i += 1

        parsed = DNSCrypt()

        # props (DNSSEC / no-log / no-filter flags)
        i, parsed.dnssec, parsed.nolog, parsed.nofilter = props(b, i)

        # LP(addr [:port])
        i, parsed.addr = LP_decode(b, i)

        # LP(pk)
        i, parsed.pk = LP_pk(b, i)

        # LP(providerName)
        i, parsed.provider = LP_decode(b, i)

        return parsed


class DNSoverHTTPS(Base):
    """DNS-over-HTTPS server stamp (protocol identifier 0x02).

    Layout: ``0x02 || props || LP(addr) || VLP(hash1, ..., hashn) ||
    LP(hostname[:port]) || LP(path) [|| VLP(bootstrap_ip1, ..., bootstrap_ipn)]``.
    """

    dnssec: bool = False
    nolog: bool = False
    nofilter: bool = False
    addr: str = None
    # NOTE: list defaults on the class are shared by all instances; parse()
    # assigns per-instance values for both list fields.
    hashs: list[str] = []
    hostname: str = None
    path: str = None
    bootstrap_ipi: list[str] = []

    @staticmethod
    def parse(stamp: str):
        """Decode a DNS-over-HTTPS ``sdns://`` stamp.

        :param stamp: stamp string, with or without the ``sdns://`` prefix.
        :returns: a populated :class:`DNSoverHTTPS`; ``bootstrap_ipi`` is an
            empty list when the optional bootstrap section is absent.
        :raises ValueError: if the payload is empty or does not start with
            0x02 (``binascii.Error``, a ``ValueError`` subclass, for bad base64).
        """
        # Stamps are unpadded URL-safe base64; "==" supplies enough padding for
        # any length and the lenient decoder tolerates the surplus.
        b = base64.urlsafe_b64decode(stamp.removeprefix("sdns://") + "==")

        # 0x02 -- compared as a slice so an empty payload raises ValueError
        # rather than IndexError.
        i = 0
        if b[i:i + 1] != b"\x02":
            raise ValueError("DNSoverHTTPS: 0x02")
        i += 1

        parsed = DNSoverHTTPS()

        # props (DNSSEC / no-log / no-filter flags)
        i, parsed.dnssec, parsed.nolog, parsed.nofilter = props(b, i)

        # LP(addr)
        i, parsed.addr = LP_decode(b, i)

        # VLP(hash1, hash2, ...hashn)
        i, parsed.hashs = VLP_hashs(b, i)

        # LP(hostname [:port])
        i, parsed.hostname = LP_decode(b, i)

        # LP(path)
        i, parsed.path = LP_decode(b, i)

        # VLP(bootstrap_ip1, bootstrap_ip2, ...bootstrap_ipn) (optional).
        # A fresh per-instance default keeps an absent section from aliasing
        # the shared class-level list and keeps the field in to_json(), which
        # serialises only instance attributes.
        parsed.bootstrap_ipi = []
        if i < len(b):
            i, parsed.bootstrap_ipi = VLP_bootstrap_ipi(b, i)

        return parsed


class DNSoverTLS(Base):
    """DNS-over-TLS server stamp (protocol identifier 0x03).

    Layout: ``0x03 || props || LP(addr) || VLP(hash1, ..., hashn) ||
    LP(hostname[:port]) [|| VLP(bootstrap_ip1, ..., bootstrap_ipn)]``.
    """

    dnssec: bool = False
    nolog: bool = False
    nofilter: bool = False
    addr: str = None
    # NOTE: list defaults on the class are shared by all instances; parse()
    # assigns per-instance values for both list fields.
    hashs: list[str] = []
    hostname: str = None
    bootstrap_ipi: list[str] = []

    @staticmethod
    def parse(stamp: str):
        """Decode a DNS-over-TLS ``sdns://`` stamp.

        :param stamp: stamp string, with or without the ``sdns://`` prefix.
        :returns: a populated :class:`DNSoverTLS`; ``bootstrap_ipi`` is an
            empty list when the optional bootstrap section is absent.
        :raises ValueError: if the payload is empty or does not start with
            0x03 (``binascii.Error``, a ``ValueError`` subclass, for bad base64).
        """
        # Stamps are unpadded URL-safe base64; "==" supplies enough padding for
        # any length and the lenient decoder tolerates the surplus.
        b = base64.urlsafe_b64decode(stamp.removeprefix("sdns://") + "==")

        # 0x03 -- compared as a slice so an empty payload raises ValueError
        # rather than IndexError.
        i = 0
        if b[i:i + 1] != b"\x03":
            raise ValueError("DNSoverTLS: 0x03")
        i += 1

        parsed = DNSoverTLS()

        # props (DNSSEC / no-log / no-filter flags)
        i, parsed.dnssec, parsed.nolog, parsed.nofilter = props(b, i)

        # LP(addr)
        i, parsed.addr = LP_decode(b, i)

        # VLP(hash1, hash2, ...hashn)
        i, parsed.hashs = VLP_hashs(b, i)

        # LP(hostname[:port])
        i, parsed.hostname = LP_decode(b, i)

        # VLP(bootstrap_ip1, bootstrap_ip2, ...bootstrap_ipn) (optional).
        # A fresh per-instance default keeps an absent section from aliasing
        # the shared class-level list and keeps the field in to_json(), which
        # serialises only instance attributes.
        parsed.bootstrap_ipi = []
        if i < len(b):
            i, parsed.bootstrap_ipi = VLP_bootstrap_ipi(b, i)

        return parsed


class DNSoverQUIC(Base):
    """DNS-over-QUIC server stamp (protocol identifier 0x04).

    Layout: ``0x04 || props || LP(addr) || VLP(hash1, ..., hashn) ||
    LP(hostname[:port]) [|| VLP(bootstrap_ip1, ..., bootstrap_ipn)]``.
    """

    dnssec: bool = False
    nolog: bool = False
    nofilter: bool = False
    addr: str = None
    # NOTE: list defaults on the class are shared by all instances; parse()
    # assigns per-instance values for both list fields.
    hashs: list[str] = []
    hostname: str = None
    bootstrap_ipi: list[str] = []

    @staticmethod
    def parse(stamp: str):
        """Decode a DNS-over-QUIC ``sdns://`` stamp.

        :param stamp: stamp string, with or without the ``sdns://`` prefix.
        :returns: a populated :class:`DNSoverQUIC`; ``bootstrap_ipi`` is an
            empty list when the optional bootstrap section is absent.
        :raises ValueError: if the payload is empty or does not start with
            0x04 (``binascii.Error``, a ``ValueError`` subclass, for bad base64).
        """
        # Stamps are unpadded URL-safe base64; "==" supplies enough padding for
        # any length and the lenient decoder tolerates the surplus.
        b = base64.urlsafe_b64decode(stamp.removeprefix("sdns://") + "==")

        # 0x04 -- compared as a slice so an empty payload raises ValueError
        # rather than IndexError.
        i = 0
        if b[i:i + 1] != b"\x04":
            raise ValueError("DNSoverQUIC: 0x04")
        i += 1

        parsed = DNSoverQUIC()

        # props (DNSSEC / no-log / no-filter flags)
        i, parsed.dnssec, parsed.nolog, parsed.nofilter = props(b, i)

        # LP(addr)
        i, parsed.addr = LP_decode(b, i)

        # VLP(hash1, hash2, ...hashn)
        i, parsed.hashs = VLP_hashs(b, i)

        # LP(hostname[:port])
        i, parsed.hostname = LP_decode(b, i)

        # VLP(bootstrap_ip1, bootstrap_ip2, ...bootstrap_ipn) (optional).
        # A fresh per-instance default keeps an absent section from aliasing
        # the shared class-level list and keeps the field in to_json(), which
        # serialises only instance attributes.
        parsed.bootstrap_ipi = []
        if i < len(b):
            i, parsed.bootstrap_ipi = VLP_bootstrap_ipi(b, i)

        return parsed


class ObliviousDoH(Base):
    """Oblivious DoH target stamp (protocol identifier 0x05).

    Layout: ``0x05 || props || LP(hostname[:port]) || LP(path)``.
    """

    dnssec: bool = False
    nolog: bool = False
    nofilter: bool = False
    hostname: str = None
    path: str = None

    @staticmethod
    def parse(stamp: str):
        """Decode an Oblivious DoH target ``sdns://`` stamp.

        :param stamp: stamp string, with or without the ``sdns://`` prefix.
        :returns: a populated :class:`ObliviousDoH`.
        :raises ValueError: if the payload is empty or does not start with
            0x05 (``binascii.Error``, a ``ValueError`` subclass, for bad base64).
        """
        # Stamps are unpadded URL-safe base64; "==" supplies enough padding for
        # any length and the lenient decoder tolerates the surplus.
        b = base64.urlsafe_b64decode(stamp.removeprefix("sdns://") + "==")

        # 0x05 -- compared as a slice so an empty payload raises ValueError
        # rather than IndexError.
        i = 0
        if b[i:i + 1] != b"\x05":
            raise ValueError("ObliviousDoH: 0x05")
        i += 1

        parsed = ObliviousDoH()

        # props (DNSSEC / no-log / no-filter flags)
        i, parsed.dnssec, parsed.nolog, parsed.nofilter = props(b, i)

        # LP(hostname [:port])
        i, parsed.hostname = LP_decode(b, i)

        # LP(path)
        i, parsed.path = LP_decode(b, i)

        return parsed


class DNSCryptRelay(Base):
    """Anonymized DNSCrypt relay stamp (protocol identifier 0x81).

    Layout: ``0x81 || LP(addr)``.
    """

    addr: str = None

    @staticmethod
    def parse(stamp: str):
        """Decode a DNSCrypt relay ``sdns://`` stamp.

        :param stamp: stamp string, with or without the ``sdns://`` prefix.
        :returns: a populated :class:`DNSCryptRelay`.
        :raises ValueError: if the payload is empty or does not start with
            0x81 (``binascii.Error``, a ``ValueError`` subclass, for bad base64).
        """
        # Stamps are unpadded URL-safe base64; "==" supplies enough padding for
        # any length and the lenient decoder tolerates the surplus.
        b = base64.urlsafe_b64decode(stamp.removeprefix("sdns://") + "==")

        # 0x81 -- compared as a slice so an empty payload raises ValueError
        # rather than IndexError.
        i = 0
        if b[i:i + 1] != b"\x81":
            raise ValueError("DNSCryptRelay: 0x81")
        i += 1

        parsed = DNSCryptRelay()

        # LP(addr) -- relay stamps carry no props byte.
        i, parsed.addr = LP_decode(b, i)

        return parsed


class ObliviousDoHRelay(Base):
    """Oblivious DoH relay stamp (protocol identifier 0x85).

    Layout: ``0x85 || props || LP(addr) || VLP(hash1, ..., hashn) ||
    LP(hostname[:port]) || LP(path) [|| VLP(bootstrap_ip1, ..., bootstrap_ipn)]``.
    """

    dnssec: bool = False
    nolog: bool = False
    nofilter: bool = False
    addr: str = None
    # NOTE: list defaults on the class are shared by all instances; parse()
    # assigns per-instance values for both list fields.
    hashs: list[str] = []
    hostname: str = None
    path: str = None
    bootstrap_ipi: list[str] = []

    @staticmethod
    def parse(stamp: str):
        """Decode an Oblivious DoH relay ``sdns://`` stamp.

        :param stamp: stamp string, with or without the ``sdns://`` prefix.
        :returns: a populated :class:`ObliviousDoHRelay`; ``bootstrap_ipi`` is
            an empty list when the optional bootstrap section is absent.
        :raises ValueError: if the payload is empty or does not start with
            0x85 (``binascii.Error``, a ``ValueError`` subclass, for bad base64).
        """
        # Stamps are unpadded URL-safe base64; "==" supplies enough padding for
        # any length and the lenient decoder tolerates the surplus.
        b = base64.urlsafe_b64decode(stamp.removeprefix("sdns://") + "==")

        # 0x85 -- compared as a slice so an empty payload raises ValueError
        # rather than IndexError.
        i = 0
        if b[i:i + 1] != b"\x85":
            raise ValueError("ObliviousDoHRelay: 0x85")
        i += 1

        parsed = ObliviousDoHRelay()

        # props (DNSSEC / no-log / no-filter flags)
        i, parsed.dnssec, parsed.nolog, parsed.nofilter = props(b, i)

        # LP(addr)
        i, parsed.addr = LP_decode(b, i)

        # VLP(hash1, hash2, ...hashn)
        i, parsed.hashs = VLP_hashs(b, i)

        # LP(hostname [:port])
        i, parsed.hostname = LP_decode(b, i)

        # LP(path)
        i, parsed.path = LP_decode(b, i)

        # VLP(bootstrap_ip1, bootstrap_ip2, ...bootstrap_ipn) (optional).
        # A fresh per-instance default keeps an absent section from aliasing
        # the shared class-level list and keeps the field in to_json(), which
        # serialises only instance attributes.
        parsed.bootstrap_ipi = []
        if i < len(b):
            i, parsed.bootstrap_ipi = VLP_bootstrap_ipi(b, i)

        return parsed


class PlainDNS(Base):
    """Plain (unencrypted) DNS server stamp (protocol identifier 0x00).

    Layout: ``0x00 || props || LP(addr)``.
    """

    dnssec: bool = False
    nolog: bool = False
    nofilter: bool = False
    addr: str = None

    @staticmethod
    def parse(stamp: str):
        """Decode a plain DNS ``sdns://`` stamp.

        :param stamp: stamp string, with or without the ``sdns://`` prefix.
        :returns: a populated :class:`PlainDNS`.
        :raises ValueError: if the payload is empty or does not start with
            0x00 (``binascii.Error``, a ``ValueError`` subclass, for bad base64).
        """
        # Stamps are unpadded URL-safe base64; "==" supplies enough padding for
        # any length and the lenient decoder tolerates the surplus.
        b = base64.urlsafe_b64decode(stamp.removeprefix("sdns://") + "==")

        # 0x00 -- compared as a slice so an empty payload raises ValueError
        # rather than IndexError.
        i = 0
        if b[i:i + 1] != b"\x00":
            raise ValueError("PlainDNS: 0x00")
        i += 1

        parsed = PlainDNS()

        # props (DNSSEC / no-log / no-filter flags)
        i, parsed.dnssec, parsed.nolog, parsed.nofilter = props(b, i)

        # LP(addr)
        i, parsed.addr = LP_decode(b, i)

        return parsed