# V1.1 09/11/2022 - Modified by Theodore Zacharia from the original to allow for mass image generation via a
# batch file containing the image parameters.  This logic was originally outside in a shell script but it
# was brought into this script so that the initialisation of the scheduler and engine, which takes considerable time,
# is done only ONCE per whole set of images.  Many minutes of processing time saved.
#
# -*- coding: utf-8 -*-
import argparse
import os
# engine
from stable_diffusion_engine import StableDiffusionEngine
# scheduler
from diffusers import LMSDiscreteScheduler, PNDMScheduler
# utils
import cv2
import numpy as np
# additional for batch
import sys
import csv
import time
from datetime import datetime
import subprocess

# Write the tags if required.  There are a number of libraries used to communicate with Exif;
# we could use IPTCInfo but this currently only supports JPG and, as our images can be png as well,
# I wanted to keep it simple with minimum requirements, thus just using subprocess to call exiftool
# instead, BUT this must therefore be installed.
def setMetaData(afilename, atheprompt):
    """Stamp keyword, by-line and caption tags onto an image file using exiftool.

    Args:
        afilename: path of the image file to tag; exiftool is run with
            ``-overwrite_original`` against this file.
        atheprompt: prompt text, stored in the caption as ``Prompt - <text>``.

    Raises:
        FileNotFoundError: raised by ``subprocess.run`` when the ``exiftool``
            executable cannot be found.
    """
    print ("Updating MetaData")
    # "-Tag=value" REPLACES any existing matching tags; switching the two Keywords
    # arguments to "-Keywords+=value" would UPDATE/APPEND to existing ones instead
    exiftool_cmd = [
        "exiftool",
        "-overwrite_original",
        "-Keywords=ai art",
        "-Keywords=Stable Diffusion OpenVINO",
        "-By-line=Theodore Zacharia",
        "-Caption-Abstract=Prompt - " + atheprompt,
        afilename,
    ]
    result = subprocess.run(exiftool_cmd)
    # the exit status used to be thrown away, so a failed tag write went unnoticed
    if result.returncode != 0:
        print ("WARNING: exiftool exited with code", result.returncode, "while tagging", afilename)

def _load_batch_rows(batchfile):
    """Read the CSV batch file into a list of dicts keyed ``column_0`` .. ``column_N``."""
    # the with-block closes the file again; newline='' is what the csv module asks for
    with open(batchfile, 'r', newline='') as handle:
        # blank lines come back from csv.reader as empty lists, there is nothing to run for those
        return [
            {'column_%s' % idx: value for idx, value in enumerate(fields)}
            for fields in csv.reader(handle)
            if fields
        ]


def _build_output_path(args, stamp, seed, strength, guidance, num_steps):
    """Work out the file path that the image for one parameter row is written to."""
    if args.output is not None:
        # an explicit name is used as given, so every row of a batch writes to this same file
        return args.outputdir + "/" + args.output
    # str.format stringifies each value, so the int/float/None values of the one-off
    # command line row work just as well as the strings read from a batch file
    name = "img{}_se{}_sr{}_gu{}_st{}".format(stamp, seed, strength, guidance, num_steps)
    # identity check: only a real boolean True turns the long name on
    if args.longfilename is True:
        name += "_" + args.prompt.replace(' ', '_')
    return args.outputdir + "/" + name + ".png"


def main(args):
    """Generate one image per parameter row, building the scheduler and engine only once.

    The rows come from the CSV file named by ``args.batchfile`` or, when no batch
    file is given, from a single row made out of the command line values.  The
    columns are, in order: seed, strength, strength steps, guidance scale,
    guidance steps, number of inference steps.  The two "steps" columns (2 and 4)
    are read and echoed but not otherwise used here.

    Args:
        args: the parsed command line options (an ``argparse.Namespace``).
    """
    if args.batchfile is not None:
        # if you provide a batch file, switch to running multiple images herein, much faster
        csvdics = _load_batch_rows(args.batchfile)
        for row in csvdics:
            print (row)
    else:
        # no batch file, use the params to create a one off row to use same logic for rest of script
        csvdics = [{
            'column_0': args.seed,
            'column_1': args.strength,
            'column_2': 0,
            'column_3': args.guidance_scale,
            'column_4': 0,
            'column_5': args.num_inference_steps,
        }]

    if args.seed is not None:
        np.random.seed(args.seed)

    # this stuff is slow, so only do it ONCE
    if args.init_image is None:
        scheduler = LMSDiscreteScheduler(
            beta_start=args.beta_start,
            beta_end=args.beta_end,
            beta_schedule=args.beta_schedule,
            tensor_format="np"
        )
    else:
        scheduler = PNDMScheduler(
            beta_start=args.beta_start,
            beta_end=args.beta_end,
            beta_schedule=args.beta_schedule,
            skip_prk_steps = True,
            tensor_format="np"
        )

    engine = StableDiffusionEngine(
        model = args.model,
        scheduler = scheduler,
        tokenizer = args.tokenizer
    )

    # now process the different values
    for row in csvdics:
        batch_seed = row['column_0']
        batch_strength = row['column_1']
        batch_stength_steps = row['column_2']
        batch_guidance = row['column_3']
        batch_guidance_steps = row['column_4']
        batch_num_steps = row['column_5']

        print ("Processing with:", batch_seed, batch_strength, batch_stength_steps, batch_guidance, batch_guidance_steps, batch_num_steps, args.prompt)
        ftime = int(time.time())    # convert to int to lose the milliseconds
        output = _build_output_path(args, ftime, batch_seed, batch_strength, batch_guidance, batch_num_steps)
        print ("Will output to:", output)

        # no seed for this row (--seed not given, or an empty CSV field): leave the
        # random state alone rather than failing on int(None) / int('')
        if batch_seed not in (None, ''):
            np.random.seed(int(batch_seed))
        image = engine(
            prompt = args.prompt,
            init_image = None if args.init_image is None else cv2.imread(args.init_image),
            mask = None if args.mask is None else cv2.imread(args.mask, 0),
            strength = float(batch_strength),
            num_inference_steps = int(batch_num_steps),
            guidance_scale = float(batch_guidance),
            eta = args.eta
        )

        # cv2.imwrite reports failure through its return value, so check it instead
        # of announcing (and tagging) a file that was never written
        if not cv2.imwrite(output, image):
            print ("ERROR: could not write image to:", output)
            continue

        # identity check: only a real boolean True turns the tagging on
        if args.iptc is True:
            setMetaData(output, args.prompt)

        print ("Image written to:", output , "at", datetime.now().strftime("%d/%m/%Y %H:%M:%S") )


# Mainline    

def build_arg_parser():
    """Create the command line parser for this script.

    Returns:
        argparse.ArgumentParser: parser with all pipeline, scheduler, diffusion,
        img2img, inpainting and input/output options registered.
    """
    parser = argparse.ArgumentParser()
    # pipeline configure
    # store_true flags default to the boolean False; the string "false" used before is truthy
    parser.add_argument("-l", "--longfilename", default=False, action="store_true", help="set output file to long format name")
    parser.add_argument("--iptc", default=False, action="store_true", help="set parameters used to generate into IPTC tags")
    parser.add_argument("--model", type=str, default="bes-dev/stable-diffusion-v1-4-openvino", help="model name")
    # randomizer params
    parser.add_argument("--seed", type=int, default=None, help="random seed for generating consistent images per prompt")
    # scheduler params
    parser.add_argument("--beta-start", type=float, default=0.00085, help="LMSDiscreteScheduler::beta_start")
    parser.add_argument("--beta-end", type=float, default=0.012, help="LMSDiscreteScheduler::beta_end")
    parser.add_argument("--beta-schedule", type=str, default="scaled_linear", help="LMSDiscreteScheduler::beta_schedule")
    # diffusion params
    parser.add_argument("--num-inference-steps", type=int, default=32, help="num inference steps")
    parser.add_argument("--guidance-scale", type=float, default=7.5, help="guidance scale")
    parser.add_argument("--eta", type=float, default=0.0, help="eta")
    # tokenizer
    parser.add_argument("--tokenizer", type=str, default="openai/clip-vit-large-patch14", help="tokenizer")
    # prompt
    parser.add_argument("--prompt", type=str, default="Street-art painting of Emilia Clarke in style of Banksy, photorealism", help="prompt")
    # img2img params
    parser.add_argument("--init-image", type=str, default=None, help="path to initial image")
    parser.add_argument("--strength", type=float, default=0.5, help="how strong the initial image should be noised [0.0, 1.0]")
    # inpainting
    parser.add_argument("--mask", type=str, default=None, help="mask of the region to inpaint on the initial image")
    # input and output files
    parser.add_argument("--output", type=str, default=None, help="output image name")
    parser.add_argument("--outputdir", type=str, default=".", help="output image name dir")
    parser.add_argument("--batchfile", type=str, default=None, help="input batch file name, overrides command line values")
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    main(args)