import base64


def parse(stamp: str):
    """Decode an ``sdns://`` DNS stamp and dispatch on its protocol byte.

    The first decoded byte selects the parser:
    0x00 PlainDNS, 0x01 DNSCrypt, 0x02 DNSoverHTTPS, 0x03 DNSoverTLS,
    0x04 DNSoverQUIC, 0x05 ObliviousDoH, 0x81 DNSCryptRelay,
    0x85 ObliviousDoHrelay.

    Returns:
        An instance of the class matching the stamp's protocol byte.

    Raises:
        ValueError: if the stamp decodes to no bytes or its protocol byte is
            not one of the types listed above (previously an unknown type
            silently returned None).
    """
    # "==" is appended because stamps are unpadded base64url; surplus padding
    # is tolerated by the non-strict decoder.
    b = base64.urlsafe_b64decode(stamp.removeprefix("sdns://") + "==")
    if not b:
        raise ValueError("empty DNS stamp")
    kind = b[0]
    if kind == 0x01:
        return DNSCrypt.parse(stamp)
    elif kind == 0x02:
        return DNSoverHTTPS.parse(stamp)
    elif kind == 0x03:
        return DNSoverTLS.parse(stamp)
    elif kind == 0x04:
        return DNSoverQUIC.parse(stamp)
    elif kind == 0x05:
        return ObliviousDoH.parse(stamp)
    elif kind == 0x81:
        return DNSCryptRelay.parse(stamp)
    elif kind == 0x85:
        return ObliviousDoHrelay.parse(stamp)
    elif kind == 0x00:
        return PlainDNS.parse(stamp)
    raise ValueError(f"unsupported DNS stamp type: {kind:#04x}")


# Code from https://github.com/DNSCrypt/dnscrypt-resolvers/blob/21fcbaf858112c63fed2a504714cc829bd654483/utils/format.py#L101-L141
class DNSCrypt:
    """Parsed DNSCrypt server stamp (protocol byte 0x01).

    Wire layout decoded by ``parse``::

        0x01 || props(8) || LP(addr[:port]) || LP(pk) || LP(providerName)
    """

    dnssec: bool = False
    nolog: bool = False
    nofilter: bool = False
    addr: str = None
    pk: str = None  # 32-byte key as 16 colon-separated groups of 4 uppercase hex digits
    provider: str = None

    def __repr__(self) -> str:
        fields = ("dnssec", "nolog", "nofilter", "addr", "pk", "provider")
        body = ", ".join(f"{name}={getattr(self, name)!r}" for name in fields)
        return f"{type(self).__name__}({body})"

    @staticmethod
    def parse(stamp: str):
        """Parse an ``sdns://`` DNSCrypt stamp into a new instance.

        Raises:
            ValueError: if the protocol byte is not 0x01, the public key is not
                32 bytes, or the stamp ends in the middle of a field.
        """
        b = base64.urlsafe_b64decode(stamp.removeprefix("sdns://") + "==")
        pos = 0

        def take(n: int) -> bytes:
            # Consume the next n bytes. Plain slicing would silently return a
            # short field for a truncated stamp, so refuse to read past the end.
            nonlocal pos
            if pos + n > len(b):
                raise ValueError("truncated DNS stamp")
            chunk = b[pos:pos + n]
            pos += n
            return chunk

        # 0x01
        if take(1)[0] != 0x01:
            raise ValueError()

        parsed = DNSCrypt()

        # props: 8 bytes; the flags are read from the low bits of the first one
        props = take(8)[0]
        parsed.dnssec = bool(props & 0x01)
        parsed.nolog = bool(props & 0x02)
        parsed.nofilter = bool(props & 0x04)

        # LP(addr [:port])
        parsed.addr = take(take(1)[0]).decode("utf-8")

        # LP(pk): must be exactly 32 bytes
        pk_len = take(1)[0]
        if pk_len != 32:
            raise ValueError()
        hpk = take(pk_len).hex().upper()
        parsed.pk = ":".join(hpk[k:k + 4] for k in range(0, len(hpk), 4))

        # LP(providerName)
        parsed.provider = take(take(1)[0]).decode("utf-8")

        return parsed


class DNSoverHTTPS:
    """Parsed DNS-over-HTTPS server stamp (protocol byte 0x02).

    Wire layout decoded by ``parse``::

        0x02 || props(8) || LP(addr) || VLP(hash1, ..., hashn)
             || LP(hostname[:port]) || LP(path) [|| VLP(bootstrap_ip1, ...)]
    """

    dnssec: bool = False
    nolog: bool = False
    nofilter: bool = False
    addr: str = None
    hashs: list[str]  # hex-encoded 32-byte hashes (set per instance in __init__)
    hostname: str = None
    path: str = None
    bootstrap_ipi: list[str]  # bootstrap IP strings (set per instance in __init__)

    def __init__(self):
        # Give every instance its own lists. A class-level ``[]`` default is a
        # single object shared by all instances, so appending to it in parse()
        # made each parsed stamp accumulate the entries of every earlier one.
        self.hashs = []
        self.bootstrap_ipi = []

    def __repr__(self) -> str:
        fields = ("dnssec", "nolog", "nofilter", "addr", "hashs", "hostname",
                  "path", "bootstrap_ipi")
        body = ", ".join(f"{name}={getattr(self, name)!r}" for name in fields)
        return f"{type(self).__name__}({body})"

    @staticmethod
    def parse(stamp: str):
        """Parse an ``sdns://`` DNS-over-HTTPS stamp into a new instance.

        Empty entries in either VLP list are skipped.

        Raises:
            ValueError: if the protocol byte is not 0x02, a hash is neither
                empty nor 32 bytes, or the stamp ends in the middle of a field.
        """
        b = base64.urlsafe_b64decode(stamp.removeprefix("sdns://") + "==")
        pos = 0

        def take(n: int) -> bytes:
            # Consume the next n bytes. Plain slicing would silently return a
            # short field for a truncated stamp, so refuse to read past the end.
            nonlocal pos
            if pos + n > len(b):
                raise ValueError("truncated DNS stamp")
            chunk = b[pos:pos + n]
            pos += n
            return chunk

        def take_vlp() -> list[bytes]:
            # VLP(x1, ..., xn): the low 7 bits of each length byte give the
            # item's length; 0x80 is set on every length byte except the last.
            items = []
            while True:
                header = take(1)[0]
                items.append(take(header & 0x7F))
                if not header & 0x80:
                    return items

        # 0x02
        if take(1)[0] != 0x02:
            raise ValueError()

        parsed = DNSoverHTTPS()

        # props: 8 bytes; the flags are read from the low bits of the first one
        props = take(8)[0]
        parsed.dnssec = bool(props & 0x01)
        parsed.nolog = bool(props & 0x02)
        parsed.nofilter = bool(props & 0x04)

        # LP(addr)
        parsed.addr = take(take(1)[0]).decode("utf-8")

        # VLP(hash1, hash2, ...hashn): each entry is empty or 32 bytes
        hashes = take_vlp()
        if any(len(h) not in (0, 32) for h in hashes):
            raise ValueError()
        parsed.hashs.extend(h.hex() for h in hashes if h)

        # LP(hostname [:port])
        parsed.hostname = take(take(1)[0]).decode("utf-8")

        # LP(path)
        parsed.path = take(take(1)[0]).decode("utf-8")

        # VLP(bootstrap_ip1, bootstrap_ip2, ...bootstrap_ipn) (optional)
        if pos < len(b):
            parsed.bootstrap_ipi.extend(
                ip.decode("utf-8") for ip in take_vlp() if ip)

        return parsed


class DNSoverTLS:
    """Parsed DNS-over-TLS server stamp (protocol byte 0x03).

    Wire layout decoded by ``parse``::

        0x03 || props(8) || LP(addr) || VLP(hash1, ..., hashn)
             || LP(hostname[:port]) [|| VLP(bootstrap_ip1, ...)]
    """

    dnssec: bool = False
    nolog: bool = False
    nofilter: bool = False
    addr: str = None
    hashs: list[str]  # hex-encoded 32-byte hashes (set per instance in __init__)
    hostname: str = None
    bootstrap_ipi: list[str]  # bootstrap IP strings (set per instance in __init__)

    def __init__(self):
        # Give every instance its own lists. A class-level ``[]`` default is a
        # single object shared by all instances, so appending to it in parse()
        # made each parsed stamp accumulate the entries of every earlier one.
        self.hashs = []
        self.bootstrap_ipi = []

    def __repr__(self) -> str:
        fields = ("dnssec", "nolog", "nofilter", "addr", "hashs", "hostname",
                  "bootstrap_ipi")
        body = ", ".join(f"{name}={getattr(self, name)!r}" for name in fields)
        return f"{type(self).__name__}({body})"

    @staticmethod
    def parse(stamp: str):
        """Parse an ``sdns://`` DNS-over-TLS stamp into a new instance.

        Empty entries in either VLP list are skipped.

        Raises:
            ValueError: if the protocol byte is not 0x03, a hash is neither
                empty nor 32 bytes, or the stamp ends in the middle of a field.
        """
        b = base64.urlsafe_b64decode(stamp.removeprefix("sdns://") + "==")
        pos = 0

        def take(n: int) -> bytes:
            # Consume the next n bytes. Plain slicing would silently return a
            # short field for a truncated stamp, so refuse to read past the end.
            nonlocal pos
            if pos + n > len(b):
                raise ValueError("truncated DNS stamp")
            chunk = b[pos:pos + n]
            pos += n
            return chunk

        def take_vlp() -> list[bytes]:
            # VLP(x1, ..., xn): the low 7 bits of each length byte give the
            # item's length; 0x80 is set on every length byte except the last.
            items = []
            while True:
                header = take(1)[0]
                items.append(take(header & 0x7F))
                if not header & 0x80:
                    return items

        # 0x03
        if take(1)[0] != 0x03:
            raise ValueError()

        parsed = DNSoverTLS()

        # props: 8 bytes; the flags are read from the low bits of the first one
        props = take(8)[0]
        parsed.dnssec = bool(props & 0x01)
        parsed.nolog = bool(props & 0x02)
        parsed.nofilter = bool(props & 0x04)

        # LP(addr)
        parsed.addr = take(take(1)[0]).decode("utf-8")

        # VLP(hash1, hash2, ...hashn): each entry is empty or 32 bytes
        hashes = take_vlp()
        if any(len(h) not in (0, 32) for h in hashes):
            raise ValueError()
        parsed.hashs.extend(h.hex() for h in hashes if h)

        # LP(hostname[:port])
        parsed.hostname = take(take(1)[0]).decode("utf-8")

        # VLP(bootstrap_ip1, bootstrap_ip2, ...bootstrap_ipn) (optional)
        if pos < len(b):
            parsed.bootstrap_ipi.extend(
                ip.decode("utf-8") for ip in take_vlp() if ip)

        return parsed


class DNSoverQUIC:
    """Parsed DNS-over-QUIC server stamp (protocol byte 0x04).

    Wire layout decoded by ``parse``::

        0x04 || props(8) || LP(addr) || VLP(hash1, ..., hashn)
             || LP(hostname[:port]) [|| VLP(bootstrap_ip1, ...)]
    """

    dnssec: bool = False
    nolog: bool = False
    nofilter: bool = False
    addr: str = None
    hashs: list[str]  # hex-encoded 32-byte hashes (set per instance in __init__)
    hostname: str = None
    bootstrap_ipi: list[str]  # bootstrap IP strings (set per instance in __init__)

    def __init__(self):
        # Give every instance its own lists. A class-level ``[]`` default is a
        # single object shared by all instances, so appending to it in parse()
        # made each parsed stamp accumulate the entries of every earlier one.
        self.hashs = []
        self.bootstrap_ipi = []

    def __repr__(self) -> str:
        fields = ("dnssec", "nolog", "nofilter", "addr", "hashs", "hostname",
                  "bootstrap_ipi")
        body = ", ".join(f"{name}={getattr(self, name)!r}" for name in fields)
        return f"{type(self).__name__}({body})"

    @staticmethod
    def parse(stamp: str):
        """Parse an ``sdns://`` DNS-over-QUIC stamp into a new instance.

        Empty entries in either VLP list are skipped.

        Raises:
            ValueError: if the protocol byte is not 0x04, a hash is neither
                empty nor 32 bytes, or the stamp ends in the middle of a field.
        """
        b = base64.urlsafe_b64decode(stamp.removeprefix("sdns://") + "==")
        pos = 0

        def take(n: int) -> bytes:
            # Consume the next n bytes. Plain slicing would silently return a
            # short field for a truncated stamp, so refuse to read past the end.
            nonlocal pos
            if pos + n > len(b):
                raise ValueError("truncated DNS stamp")
            chunk = b[pos:pos + n]
            pos += n
            return chunk

        def take_vlp() -> list[bytes]:
            # VLP(x1, ..., xn): the low 7 bits of each length byte give the
            # item's length; 0x80 is set on every length byte except the last.
            items = []
            while True:
                header = take(1)[0]
                items.append(take(header & 0x7F))
                if not header & 0x80:
                    return items

        # 0x04
        if take(1)[0] != 0x04:
            raise ValueError()

        parsed = DNSoverQUIC()

        # props: 8 bytes; the flags are read from the low bits of the first one
        props = take(8)[0]
        parsed.dnssec = bool(props & 0x01)
        parsed.nolog = bool(props & 0x02)
        parsed.nofilter = bool(props & 0x04)

        # LP(addr)
        parsed.addr = take(take(1)[0]).decode("utf-8")

        # VLP(hash1, hash2, ...hashn): each entry is empty or 32 bytes
        hashes = take_vlp()
        if any(len(h) not in (0, 32) for h in hashes):
            raise ValueError()
        parsed.hashs.extend(h.hex() for h in hashes if h)

        # LP(hostname[:port])
        parsed.hostname = take(take(1)[0]).decode("utf-8")

        # VLP(bootstrap_ip1, bootstrap_ip2, ...bootstrap_ipn) (optional)
        if pos < len(b):
            parsed.bootstrap_ipi.extend(
                ip.decode("utf-8") for ip in take_vlp() if ip)

        return parsed


class ObliviousDoH:
    """Parsed Oblivious DoH target stamp (protocol byte 0x05).

    Wire layout decoded by ``parse``::

        0x05 || props(8) || LP(hostname[:port]) || LP(path)
    """

    dnssec: bool = False
    nolog: bool = False
    nofilter: bool = False
    hostname: str = None
    path: str = None

    def __repr__(self) -> str:
        fields = ("dnssec", "nolog", "nofilter", "hostname", "path")
        body = ", ".join(f"{name}={getattr(self, name)!r}" for name in fields)
        return f"{type(self).__name__}({body})"

    @staticmethod
    def parse(stamp: str):
        """Parse an ``sdns://`` Oblivious DoH target stamp into a new instance.

        Raises:
            ValueError: if the protocol byte is not 0x05 or the stamp ends in
                the middle of a field.
        """
        b = base64.urlsafe_b64decode(stamp.removeprefix("sdns://") + "==")
        pos = 0

        def take(n: int) -> bytes:
            # Consume the next n bytes. Plain slicing would silently return a
            # short field for a truncated stamp, so refuse to read past the end.
            nonlocal pos
            if pos + n > len(b):
                raise ValueError("truncated DNS stamp")
            chunk = b[pos:pos + n]
            pos += n
            return chunk

        # 0x05
        if take(1)[0] != 0x05:
            raise ValueError()

        parsed = ObliviousDoH()

        # props: 8 bytes; the flags are read from the low bits of the first one
        props = take(8)[0]
        parsed.dnssec = bool(props & 0x01)
        parsed.nolog = bool(props & 0x02)
        parsed.nofilter = bool(props & 0x04)

        # LP(hostname [:port])
        parsed.hostname = take(take(1)[0]).decode("utf-8")

        # LP(path)
        parsed.path = take(take(1)[0]).decode("utf-8")

        return parsed


class DNSCryptRelay:
    """Parsed anonymized-DNSCrypt relay stamp (protocol byte 0x81).

    Wire layout decoded by ``parse``::

        0x81 || LP(addr)
    """

    addr: str = None

    def __repr__(self) -> str:
        return f"{type(self).__name__}(addr={self.addr!r})"

    @staticmethod
    def parse(stamp: str):
        """Parse an ``sdns://`` DNSCrypt relay stamp into a new instance.

        Raises:
            ValueError: if the protocol byte is not 0x81 or the stamp ends in
                the middle of the address field.
        """
        b = base64.urlsafe_b64decode(stamp.removeprefix("sdns://") + "==")
        pos = 0

        def take(n: int) -> bytes:
            # Consume the next n bytes. Plain slicing would silently return a
            # short field for a truncated stamp, so refuse to read past the end.
            nonlocal pos
            if pos + n > len(b):
                raise ValueError("truncated DNS stamp")
            chunk = b[pos:pos + n]
            pos += n
            return chunk

        # 0x81
        if take(1)[0] != 0x81:
            raise ValueError()

        parsed = DNSCryptRelay()

        # LP(addr)
        parsed.addr = take(take(1)[0]).decode("utf-8")

        return parsed


class ObliviousDoHrelay:
    """Parsed Oblivious DoH relay stamp (protocol byte 0x85).

    Wire layout decoded by ``parse``::

        0x85 || props(8) || LP(addr) || VLP(hash1, ..., hashn)
             || LP(hostname[:port]) || LP(path) [|| VLP(bootstrap_ip1, ...)]
    """

    dnssec: bool = False
    nolog: bool = False
    nofilter: bool = False
    addr: str = None
    hashs: list[str]  # hex-encoded 32-byte hashes (set per instance in __init__)
    hostname: str = None
    path: str = None
    bootstrap_ipi: list[str]  # bootstrap IP strings (set per instance in __init__)

    def __init__(self):
        # Give every instance its own lists. A class-level ``[]`` default is a
        # single object shared by all instances, so appending to it in parse()
        # made each parsed stamp accumulate the entries of every earlier one.
        self.hashs = []
        self.bootstrap_ipi = []

    def __repr__(self) -> str:
        fields = ("dnssec", "nolog", "nofilter", "addr", "hashs", "hostname",
                  "path", "bootstrap_ipi")
        body = ", ".join(f"{name}={getattr(self, name)!r}" for name in fields)
        return f"{type(self).__name__}({body})"

    @staticmethod
    def parse(stamp: str):
        """Parse an ``sdns://`` Oblivious DoH relay stamp into a new instance.

        Empty entries in either VLP list are skipped.

        Raises:
            ValueError: if the protocol byte is not 0x85, a hash is neither
                empty nor 32 bytes, or the stamp ends in the middle of a field.
        """
        b = base64.urlsafe_b64decode(stamp.removeprefix("sdns://") + "==")
        pos = 0

        def take(n: int) -> bytes:
            # Consume the next n bytes. Plain slicing would silently return a
            # short field for a truncated stamp, so refuse to read past the end.
            nonlocal pos
            if pos + n > len(b):
                raise ValueError("truncated DNS stamp")
            chunk = b[pos:pos + n]
            pos += n
            return chunk

        def take_vlp() -> list[bytes]:
            # VLP(x1, ..., xn): the low 7 bits of each length byte give the
            # item's length; 0x80 is set on every length byte except the last.
            items = []
            while True:
                header = take(1)[0]
                items.append(take(header & 0x7F))
                if not header & 0x80:
                    return items

        # 0x85
        if take(1)[0] != 0x85:
            raise ValueError()

        parsed = ObliviousDoHrelay()

        # props: 8 bytes; the flags are read from the low bits of the first one
        props = take(8)[0]
        parsed.dnssec = bool(props & 0x01)
        parsed.nolog = bool(props & 0x02)
        parsed.nofilter = bool(props & 0x04)

        # LP(addr)
        parsed.addr = take(take(1)[0]).decode("utf-8")

        # VLP(hash1, hash2, ...hashn): each entry is empty or 32 bytes
        hashes = take_vlp()
        if any(len(h) not in (0, 32) for h in hashes):
            raise ValueError()
        parsed.hashs.extend(h.hex() for h in hashes if h)

        # LP(hostname [:port])
        parsed.hostname = take(take(1)[0]).decode("utf-8")

        # LP(path)
        parsed.path = take(take(1)[0]).decode("utf-8")

        # VLP(bootstrap_ip1, bootstrap_ip2, ...bootstrap_ipn) (optional)
        if pos < len(b):
            parsed.bootstrap_ipi.extend(
                ip.decode("utf-8") for ip in take_vlp() if ip)

        return parsed


class PlainDNS:
    """Parsed plain (unencrypted) DNS server stamp (protocol byte 0x00).

    Wire layout decoded by ``parse``::

        0x00 || props(8) || LP(addr[:port])
    """

    dnssec: bool = False
    nolog: bool = False
    nofilter: bool = False
    addr: str = None

    def __repr__(self) -> str:
        fields = ("dnssec", "nolog", "nofilter", "addr")
        body = ", ".join(f"{name}={getattr(self, name)!r}" for name in fields)
        return f"{type(self).__name__}({body})"

    @staticmethod
    def parse(stamp: str):
        """Parse an ``sdns://`` plain DNS stamp into a new instance.

        Raises:
            ValueError: if the protocol byte is not 0x00 or the stamp ends in
                the middle of a field.
        """
        b = base64.urlsafe_b64decode(stamp.removeprefix("sdns://") + "==")
        pos = 0

        def take(n: int) -> bytes:
            # Consume the next n bytes. Plain slicing would silently return a
            # short field for a truncated stamp, so refuse to read past the end.
            nonlocal pos
            if pos + n > len(b):
                raise ValueError("truncated DNS stamp")
            chunk = b[pos:pos + n]
            pos += n
            return chunk

        # 0x00
        if take(1)[0] != 0x00:
            raise ValueError()

        parsed = PlainDNS()

        # props: 8 bytes; the flags are read from the low bits of the first one
        props = take(8)[0]
        parsed.dnssec = bool(props & 0x01)
        parsed.nolog = bool(props & 0x02)
        parsed.nofilter = bool(props & 0x04)

        # LP(addr)
        parsed.addr = take(take(1)[0]).decode("utf-8")

        return parsed


if __name__ == '__main__':
    # Parse and print one example stamp per supported protocol. Previously the
    # first DoH example was parsed into `t` and immediately overwritten by the
    # second without being printed; a data-driven loop prints every example.
    examples = [
        ("DNSCrypt",
         "sdns://AQMAAAAAAAAAETk0LjE0MC4xNC4xNTo1NDQzILgxXdexS27jIKRw3C7Wsao5jMnlhvhdRUXWuMm1AFq6ITIuZG5zY3J5cHQuZmFtaWx5Lm5zMS5hZGd1YXJkLmNvbQ"),
        ("DoH", "sdns://AgcAAAAAAAAADTIxNy4xNjkuMjAuMjIADWRucy5hYS5uZXQudWsKL2Rucy1xdWVyeQ"),
        ("DoH",
         "sdns://AgMAAAAAAAAADjE2My40Ny4xMTcuMTc2oMwQYNOcgym2K2-8fQ1t-TCYabmB5-Y5LVzY-kCPTYDmoPf1ryiAHod9ffOivij-FJ8ydKftKfE2_VA845jLqAsNoLNeBZUM-9gln5N1uhAYcLjDxMDsWlKXV-YxZ-neJqnooEROvWe7g_iAezkh6TiskXi4gr1QqtsRIx8ETPXwjffOoOZEumlj4zX-dly5l2sSsQ61QpS0JHd2TMs6OsyjrLL8ICquP7e_BeTIHEGU3KRFEdT5rzBHhuwa5yGECc9ioINVEGFkbC5hZGZpbHRlci5uZXQKL2Rucy1xdWVyeQ"),
        ("DoT", "sdns://AwcAAAAAAAAABzEuMS4xLjEAD29uZS5vbmUub25lLm9uZQ"),
        ("DoQ", "sdns://BAcAAAAAAAAABzEuMS4xLjEAD29uZS5vbmUub25lLm9uZQ"),
        ("oDoH", "sdns://BQcAAAAAAAAADWpwLnRpYXJhcC5vcmcFL29kb2g"),
        ("DNSCrypt relay", "sdns://gQ04Ni4xMDYuNzQuMjE5"),
        ("oDoH relay",
         "sdns://hQcAAAAAAAAADDg5LjM4LjEzMS4zOAAYb2RvaC1ubC5hbGVrYmVyZy5uZXQ6NDQzBi9wcm94eQ"),
        ("Plain DNS", "sdns://AAUAAAAAAAAABzEuMS4xLjE"),
    ]
    for label, example in examples:
        print(f"{label}: {parse(example)!r}")