#!/usr/bin/env python

# priodithpng
# Ordered dither: reduce a single-channel PNG to a lower bit depth.

# test with:
# priforgepng grl | priodithpng | kitty icat

# See http://www.efg2.com/Lab/Library/ImageProcessing/DHALF.TXT
# archived at http://web.archive.org/web/20160727202727/http://www.efg2.com/Lab/Library/ImageProcessing/DHALF.TXT

import sys

# https://docs.python.org/3.5/library/bisect.html
from bisect import bisect_left


import png


def dither(
    out,
    input,
    bitdepth=1,
):
    """Dither the input PNG `input` into an image with a smaller bit depth
    and write the result image onto `out`.

    `out` is the file-like object the new PNG is written to (via
    `png.Writer.write`).
    `input` is the source PNG, handed to `png.Reader` as its `file`.
    `bitdepth` specifies the bit depth of the new image.

    Raises ValueError if the decoded image does not have exactly one
    plane (channel).
    """

    # Encoding is what happened when the PNG was made (and also what
    # happens when we output the PNG).  Decoding is what we do to the
    # source PNG in order to process it.

    # The dithering algorithm is not completely general; it
    # can only do bit depth reduction, not arbitrary palette changes.

    r = png.Reader(file=input)

    _, _, rows, info = r.asDirect()
    planes = info["planes"]
    # An explicit exception rather than `assert`, so the check survives
    # `python -O`; run_dither only handles one sample per pixel.
    if planes != 1:
        raise ValueError(
            "can only dither single-channel images; input has %d planes"
            % planes
        )

    dithered_rows = run_dither(rows, info, bitdepth)

    ninfo = dict(info)
    ninfo["bitdepth"] = bitdepth
    w = png.Writer(**ninfo)
    w.write(out, dithered_rows)


def run_dither(rows, info, bitdepth=1):
    """Do ordered dither.

    `rows` is an iterable of rows of single-channel sample values, each
    in the range 0 to 2**info["bitdepth"] - 1 (the source bit depth).
    `bitdepth` is the output bit depth; output samples lie in the range
    0 to 2**bitdepth - 1.  The default of 1 gives a two-level result.

    Yields one list of output sample values per input row, of the same
    length as that input row.
    """

    # dither map
    M = [
        [0, 2],
        [3, 1],
    ]

    dw = len(M[0])
    dh = len(M)

    R = dh * dw

    # Largest source sample value and largest output code.
    D = 2 ** info["bitdepth"] - 1
    maxval = 2 ** bitdepth - 1

    # Mean of the map entries.  Subtracting it centres the per-cell
    # offsets on zero; for this map they are -0.375, 0.125, 0.375 and
    # -0.125, so k + 0.5 below stays within [0.125, maxval + 0.875]
    # and int() (truncation) always yields a code in 0..maxval.
    bias = sum(sum(r) for r in M) / R

    for j, row in enumerate(rows):
        mrow = M[j % dh]
        targetrow = []
        for i, v in enumerate(row):
            mo = (mrow[i % dw] - bias) / R
            # Scale the sample into output-code units, then add the offset.
            k = v * maxval / D + mo
            # Select nearest output code
            targetrow.append(int(k + 0.5))

        yield targetrow


def main(argv=None):
    """Command-line entry point.

    `argv` is the full argument vector, program name first; it defaults
    to `sys.argv`.  Options are parsed and handed to `dither`, which
    writes the resulting PNG to the binary standard output.
    """
    # https://docs.python.org/3.5/library/argparse.html
    import argparse

    parser = argparse.ArgumentParser()

    # The leading element is the program name and is not parsed.
    _progname, *arguments = sys.argv if argv is None else argv

    parser.add_argument(
        "--bitdepth",
        type=int,
        default=1,
        help="bitdepth of output",
    )
    parser.add_argument(
        "input",
        nargs="?",
        default="-",
        type=png.cli_open,
        metavar="PNG",
    )

    options = parser.parse_args(arguments)
    destination = png.binary_stdout()
    return dither(destination, **vars(options))


if __name__ == "__main__":
    main()
