import os
import csv
import time
import numpy as np
import cv2
import subprocess
import threading
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from picamera2 import Picamera2
import tifffile

# USER DEFINED PARAMETERS (edit these to configure the experiment)

USER_PARAMS = {
    # Camera parameters example
    "duration_s":    8,   # total capture duration, in s
    "framerate":     1/8,  # in frames/s; exposure is set to the full frame period (1e6 / framerate us)
    "resolution":    (1280, 800),  # (width, height) in pixels, used for both main and raw streams
    "analogue_gain": 15,  # sensor's hardware-level amplification (auto-exposure is disabled)
    "brightness": 0,  # passed as the "Brightness" camera control
    "contrast": 1,  # passed as the "Contrast" camera control
    "NoiseReductionMode": 0,  # passed as the "NoiseReductionMode" camera control
    "pixel_format":  "YUV420",  # main-stream format; frames are read assuming the Y plane comes first
    "consecutive_stable": 3,  # in-tolerance exposure/gain readings required before capture starts
    "buffer_count":  10,  # passed as buffer_count to the still configuration
    "potentiostat_venv": "/path to venv file",  # venv directory; <venv>/bin/python runs the script
    "potentiostat_script": "/path to potentiostat script",  # must create measurement_start.txt in $EXPERIMENT_RESULTS_FOLDER to start the capture

    # Analysis parameters (for ECL analysis); forwarded to the analysis
    # script as positional command-line arguments
    # Electrode area
    "electrode_area": 0.150,  # in cm^2
    # CV
    "cv_start_pot": 0.0,  # in V
    "cv_end_pot": 2,  # in V
    "cv_n_cycles": 1,  # number of cycles
    "cv_scan_rate": 0.1,  # in V/s
    # CA
    "ca_start_time": 0,  # in s
    "ca_applied_potential": 2,  # in V for CA

    # ROI parameters (also forwarded to the analysis script)
    "roi_shape": "circle",  # "circle" or "rectangle"
    "num_rois": 3,  # Number of ROIs to select
    "fixed_radius": 90,  # pixels, for circle ROI
    "fixed_width": 170,  # pixels, for rectangle ROI
    "fixed_height": 170,  # pixels, for rectangle ROI

    # Analysis script path
    "analysis_script": "/path to analysis script",

    # Auto-run analysis after capture and video encoding
    "run_analysis": True
}

# ---------------------------
# FUNCTIONS DEFINITIONS
# ---------------------------

def create_results_folder(base="results"):
    """Create a timestamped results directory and return its path.

    The directory is named ``<base>_YYYY-MM-DD_HHMMSS`` from the local time
    at the moment of the call. An already existing directory with the same
    name is reused instead of raising.
    """
    stamp = datetime.now().strftime('%Y-%m-%d_%H%M%S')
    folder_name = f"{base}_{stamp}"
    os.makedirs(folder_name, exist_ok=True)
    return folder_name

def _truncate_near_int(value, tol=1e-9):
    """Truncate ``value`` toward zero, first snapping it to the nearest
    integer when it lies within a relative tolerance of it.

    This absorbs float round-off such as ``49 * (1/49) == 0.9999999999999999``,
    which a plain ``int()`` would truncate to 0.
    """
    nearest = round(value)
    if abs(value - nearest) <= tol * max(1.0, abs(value)):
        return int(nearest)
    return int(value)

def calculate_camera_params(user_params):
    """Derive camera timing values from the user parameters.

    Returns a shallow copy of ``user_params`` extended with:

    * ``exposure_us``: exposure time in microseconds, equal to the full frame
      period ``1e6 / framerate``, truncated to an int.
    * ``frame_count``: number of frames to capture,
      ``duration_s * framerate``, truncated to an int.

    Both values tolerate float round-off: a product that is an integer up to
    rounding error is not truncated down by one. The input dict is not
    modified.

    Raises:
        ValueError: if ``framerate`` is not positive.
    """
    framerate = user_params["framerate"]
    if framerate <= 0:
        raise ValueError(f"framerate must be positive, got {framerate!r}")
    exposure_us = _truncate_near_int(1_000_000 / framerate)
    frame_count = _truncate_near_int(user_params["duration_s"] * framerate)
    params = user_params.copy()
    params.update({
        "exposure_us": exposure_us,
        "frame_count": frame_count,
    })
    return params

def _as_uint8_1d(buffer_like):
    """Return a flat uint8 view of ``buffer_like``.

    NumPy arrays are flattened with ``ravel`` and, when their dtype is not
    already uint8, reinterpreted byte-wise with ``view``. Anything else is
    wrapped with ``np.frombuffer``, so it must expose the buffer protocol.
    """
    if not isinstance(buffer_like, np.ndarray):
        return np.frombuffer(buffer_like, dtype=np.uint8)
    flat = buffer_like.ravel()
    return flat if flat.dtype == np.uint8 else flat.view(np.uint8)

def extract_y_plane_from_yuv420(buffer_like, width, height, stride=None):
    """Return the luma (Y) plane of a planar YUV420 buffer as a 2-D uint8 array.

    In planar YUV420 the full-resolution Y plane comes first, so only the
    first ``height`` rows of the buffer are read; the chroma planes after it
    are ignored.

    Args:
        buffer_like: flat frame buffer (bytes-like object or NumPy array).
        width: image width in pixels.
        height: image height in pixels.
        stride: bytes per Y row in the buffer. Defaults to ``width`` (tightly
            packed rows). Pass the stream's real stride when rows are padded
            for alignment; the padding bytes are cropped off.

    Returns:
        A ``(height, width)`` uint8 array that is an independent copy of the
        source buffer.

    Raises:
        ValueError: if ``stride < width`` or the buffer is too small.
    """
    arr = _as_uint8_1d(buffer_like)
    if stride is None:
        stride = width
    if stride < width:
        raise ValueError(f"stride ({stride}) must be >= width ({width})")
    expected_y_size = stride * height
    if arr.size < expected_y_size:
        raise ValueError(f"YUV buffer too small: {arr.size} < {expected_y_size}")
    rows = arr[:expected_y_size].reshape((height, stride))
    # Drop per-row padding; .copy() also detaches the result from the buffer.
    return rows[:, :width].copy()

def save_image(filename, image):
    """Write ``image`` to ``filename`` with OpenCV, creating parent directories.

    OpenCV chooses the encoding from the file extension.

    Raises:
        OSError: if ``cv2.imwrite`` reports failure. It signals failure by
            returning ``False`` instead of raising, which would otherwise
            lose frames silently.
    """
    directory = os.path.dirname(filename)
    if directory:  # a bare filename has no directory, and makedirs("") raises
        os.makedirs(directory, exist_ok=True)
    if not cv2.imwrite(filename, image):
        raise OSError(f"cv2.imwrite failed to write {filename}")

def _safe_save_raw_tiff(path, raw_array):
    """Write ``raw_array`` to a TIFF file with tifffile and return the path used.

    A ``.tiff`` extension is appended unless ``path`` already ends in
    ``.tif`` or ``.tiff`` (case-insensitive), so ``"frame.tif"`` is kept
    as-is rather than becoming ``"frame.tif.tiff"``.
    """
    if path.lower().endswith((".tif", ".tiff")):
        out_path = path
    else:
        out_path = path + ".tiff"
    tifffile.imwrite(out_path, raw_array)
    return out_path

def wait_for_stabilization(picam2, target_exposure, target_gain, tolerance=0.01, consecutive=3, max_checks=100,
                           poll_interval=0.1):
    """Block until exposure and gain metadata settle near their targets.

    Each check reads one metadata dict via ``picam2.capture_metadata()`` and
    tests whether ``ExposureTime`` and ``AnalogueGain`` are both within
    ``tolerance`` (relative) of their targets. The function returns once
    ``consecutive`` back-to-back readings pass. Readings that lack either key
    are skipped and neither extend nor break the streak.

    Args:
        picam2: object exposing ``capture_metadata()`` that returns a dict.
        target_exposure: expected ``ExposureTime`` value.
        target_gain: expected ``AnalogueGain`` value.
        tolerance: relative tolerance, e.g. 0.01 for 1 %.
        consecutive: number of back-to-back passing readings required.
            Values below 1 are treated as 1.
        max_checks: maximum number of metadata reads before giving up.
        poll_interval: seconds to sleep after each read (0 disables sleeping).

    Returns:
        True if the values stabilized, False if ``max_checks`` ran out.
    """
    required = max(1, int(consecutive))
    streak = 0
    for _ in range(max_checks):
        meta = picam2.capture_metadata()
        exp = meta.get("ExposureTime", None)
        gain = meta.get("AnalogueGain", None)
        if exp is not None and gain is not None:
            exp_ok = abs(exp - target_exposure) < target_exposure * tolerance
            gain_ok = abs(gain - target_gain) < target_gain * tolerance
            streak = streak + 1 if (exp_ok and gain_ok) else 0
            # Only succeed after a full run of passing readings; a partial
            # window would let a single lucky reading through.
            if streak >= required:
                return True
        if poll_interval > 0:
            time.sleep(poll_interval)
    return False

def discard_Frames_for_seconds(picam2, seconds):
    """Capture and immediately release requests for ``seconds`` seconds.

    This flushes frames that are not wanted (e.g. those produced while the
    controls were settling). Elapsed time is measured with
    ``time.monotonic`` so that wall-clock adjustments cannot shorten or
    stretch the wait.

    Returns:
        The number of requests discarded.
    """
    deadline = time.monotonic() + seconds
    discarded = 0
    while time.monotonic() < deadline:
        picam2.capture_request().release()
        discarded += 1
    return discarded

def start_potentiostat_in_venv(venv_path, script_path, results_folder):
    """Validate the potentiostat venv and script and return the command to run.

    Despite its name, this function does not start a process. It checks that
    ``<venv_path>/bin/python`` and ``script_path`` exist and returns
    ``[python_exe, script_path]`` for the caller to launch.
    ``results_folder`` is accepted but not used here.

    Raises:
        RuntimeError: if the interpreter or the script is missing (the
            interpreter is checked first).
    """
    interpreter = os.path.join(venv_path, "bin", "python")
    required = (
        (interpreter, f"Python not found: {interpreter}"),
        (script_path, f"Script not found: {script_path}"),
    )
    for candidate, message in required:
        if not os.path.exists(candidate):
            raise RuntimeError(message)
    return [interpreter, script_path]

def stream_subprocess_output(process):
    """Echo a child process's binary stdout to this process's stdout.

    Reads line by line until EOF. Each line is decoded as UTF-8, with
    undecodable bytes replaced, and flushed immediately so that output from a
    long-running child appears live even when stdout is redirected. The pipe
    is closed when reading stops, including when an error interrupts it.
    """
    pipe = process.stdout
    try:
        for line in iter(pipe.readline, b''):
            print(line.decode(errors="replace"), end='', flush=True)
    finally:
        pipe.close()

# ---------------------------
# 10-bit RAW unpacking
# ---------------------------
def unpack_y10(raw_bytes, width, height, stride=None):
    """Unpack packed 10-bit pixels into a ``(height, width)`` uint16 array.

    Every 4 pixels occupy 5 bytes. Bytes 0-3 hold the 8 most significant bits
    of pixels 0-3. Byte 4 holds their 2 least significant bits: pixel 0 in
    bits 1:0, pixel 1 in bits 3:2, pixel 2 in bits 5:4 and pixel 3 in bits 7:6.

    Args:
        raw_bytes: packed frame as any bytes-like object or uint8 buffer.
        width: image width in pixels; must be a multiple of 4.
        height: image height in pixels.
        stride: bytes per row in ``raw_bytes``. Defaults to the tightly packed
            ``width * 5 // 4``. Pass the stream stride when rows are padded.
            Bytes beyond ``height * stride`` are ignored.

    Returns:
        uint16 array with values in 0..1023.

    Raises:
        ValueError: if ``width`` is not a multiple of 4, ``stride`` is smaller
            than one packed row, or the buffer is too short.
    """
    if width % 4:
        raise ValueError(f"width must be a multiple of 4 for packed 10-bit data, got {width}")
    packed_row = width * 5 // 4
    if stride is None:
        stride = packed_row
    if stride < packed_row:
        raise ValueError(f"stride ({stride}) is smaller than a packed row ({packed_row})")
    data = np.frombuffer(raw_bytes, dtype=np.uint8)
    needed = height * stride
    if data.size < needed:
        raise ValueError(f"raw buffer too small: {data.size} < {needed}")
    # Crop row padding, then view each row as groups of 5 bytes per 4 pixels.
    rows = data[:needed].reshape(height, stride)[:, :packed_row]
    groups = rows.reshape(height, width // 4, 5).astype(np.uint16)
    msbs = groups[:, :, :4] << 2
    lsb_shifts = np.array([0, 2, 4, 6], dtype=np.uint16)
    # Broadcast the 5th byte against the 4 shifts to pull out each pixel's 2 LSBs.
    lsbs = (groups[:, :, 4:5] >> lsb_shifts) & 0b11
    return (msbs | lsbs).reshape(height, width)

# Capture function

def capture_gray_and_raw_Frames(params, results_folder):
    """Capture a frame series synchronised with the potentiostat script.

    Sequence:
      1. Configure and start the camera with fixed exposure/gain (AE off),
         wait for the metadata to settle and discard ~1 s of frames.
      2. Launch the potentiostat script from its venv, with
         ``EXPERIMENT_RESULTS_FOLDER`` set, and echo its output on a thread.
      3. Wait until the script creates
         ``<results_folder>/measurement_start.txt``, capture a sync frame and
         append the camera timing lines to that file.
      4. Capture ``params["frame_count"]`` frames paced at
         ``params["framerate"]``. Each frame's Y plane is saved as PNG, its
         raw buffer as ``.raw``, and its metadata as a row of
         ``Frames/metadata_Frames.csv``.
      5. Unpack every ``.raw`` file to a 16-bit TIFF.

    Resource handling: the camera is always stopped and closed. Once the
    potentiostat has been started, it is always waited for and its output
    thread joined, whether capture succeeds or raises. Errors from
    background PNG writes are re-raised after the TIFF conversion, so the
    raw data is still converted.

    Raises:
        RuntimeError: if the potentiostat interpreter/script is missing, or
            the potentiostat exits before creating the sync file.
    """
    frame_folder = os.path.join(results_folder, "Frames")
    raw_folder = os.path.join(frame_folder, "raw_y10")
    tiff_folder = os.path.join(frame_folder, "raw_tiff")
    png_folder = os.path.join(frame_folder, "png")
    for folder in (frame_folder, raw_folder, tiff_folder, png_folder):
        os.makedirs(folder, exist_ok=True)

    metadata_csv = os.path.join(frame_folder, "metadata_Frames.csv")
    measurement_start_path = os.path.join(results_folder, "measurement_start.txt")
    width, height = params["resolution"]

    picam2 = Picamera2()
    pot_proc = None
    stream_thread = None
    png_futures = []
    try:
        try:
            _start_camera(picam2, params)
            wait_for_stabilization(picam2, params["exposure_us"], params["analogue_gain"],
                                   tolerance=0.1, consecutive=params.get("consecutive_stable", 3))
            discard_Frames_for_seconds(picam2, 1.0)

            pot_proc, stream_thread = _launch_potentiostat(params, results_folder)
            _wait_for_sync_file(pot_proc, measurement_start_path)
            _capture_sync_frame(picam2, width, height, png_folder, raw_folder,
                                measurement_start_path)
            png_futures = _capture_frame_series(picam2, params, width, height,
                                                png_folder, raw_folder, metadata_csv)
        finally:
            # Always hand the camera back, even if capture failed part-way.
            picam2.stop()
            picam2.close()
    finally:
        # Let a started potentiostat run to completion and drain its output
        # before returning or propagating an error.
        if pot_proc is not None:
            pot_proc.wait()
        if stream_thread is not None:
            stream_thread.join()

    _convert_raw_folder_to_tiff(raw_folder, tiff_folder, width, height)

    # All futures are done (the executor was shut down); surface any PNG
    # write failure now that the raw frames have been converted.
    for future in png_futures:
        future.result()

def _start_camera(picam2, params):
    """Configure main + raw still streams with fixed exposure/gain and start the camera."""
    controls = {
        "FrameRate": params["framerate"],
        "ExposureTime": params["exposure_us"],
        "AnalogueGain": params["analogue_gain"],
        "AeEnable": False,
        "NoiseReductionMode": params["NoiseReductionMode"],
        "Brightness": params["brightness"],
        "Contrast": params["contrast"],
    }
    config = picam2.create_still_configuration(
        main={"format": params["pixel_format"], "size": params["resolution"]},
        raw={"size": params["resolution"]},
        buffer_count=params.get("buffer_count", 10),
        controls=controls,
    )
    picam2.configure(config)
    picam2.start()

def _grab_frame(picam2):
    """Capture one request and return ``(main_buffer, raw_buffer, metadata)``.

    The request is always released, even if reading a buffer fails.
    """
    request = picam2.capture_request()
    try:
        return (
            request.make_buffer("main"),
            request.make_buffer("raw"),
            request.get_metadata(),
        )
    finally:
        request.release()

def _write_bytes(path, data):
    """Write a bytes-like buffer to ``path`` unchanged."""
    with open(path, "wb") as f:
        f.write(data)

def _launch_potentiostat(params, results_folder):
    """Start the potentiostat script and a thread echoing its output.

    The child receives ``EXPERIMENT_RESULTS_FOLDER=<results_folder>`` in its
    environment, and its stderr is merged into stdout.
    Returns ``(process, thread)``.
    """
    pot_cmd = start_potentiostat_in_venv(params["potentiostat_venv"], params["potentiostat_script"], results_folder)
    pot_env = os.environ.copy()
    pot_env["EXPERIMENT_RESULTS_FOLDER"] = results_folder
    pot_proc = subprocess.Popen(pot_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=pot_env)
    stream_thread = threading.Thread(target=stream_subprocess_output, args=(pot_proc,))
    stream_thread.start()
    return pot_proc, stream_thread

def _wait_for_sync_file(proc, path, poll_interval=0.02):
    """Poll until ``path`` exists; fail fast if ``proc`` exits without creating it.

    Raises:
        RuntimeError: if the process has exited and ``path`` still does not exist.
    """
    while not os.path.exists(path):
        # Re-check the file after poll() so a script that writes it and then
        # exits between the two checks is not reported as a failure.
        if proc.poll() is not None and not os.path.exists(path):
            raise RuntimeError(
                f"Potentiostat exited (code {proc.returncode}) before creating {path}"
            )
        time.sleep(poll_interval)

def _capture_sync_frame(picam2, width, height, png_folder, raw_folder, measurement_start_path):
    """Save the sync frame and append camera timing lines to the sync file.

    The logged wall-clock time is read after the frame has been saved, so it
    lags the capture. ``SensorTimestamp`` comes from the request metadata.
    """
    yuv_buffer, raw_bytes, meta = _grab_frame(picam2)
    y_plane = extract_y_plane_from_yuv420(yuv_buffer, width, height)
    save_image(os.path.join(png_folder, "frame_sync.png"), y_plane)
    _write_bytes(os.path.join(raw_folder, "frame_sync.raw"), raw_bytes)

    with open(measurement_start_path, "a") as f:
        f.write(f"CameraFirstFrame_wallclock={time.time()}\n")
        # NOTE(review): libcamera documents SensorTimestamp in nanoseconds, so
        # the "_us" label may be wrong. Confirm before renaming it, because
        # downstream parsers may rely on this exact key.
        f.write(f"CameraSensorTimestamp_us={meta.get('SensorTimestamp',0)}\n")
        f.flush()
        os.fsync(f.fileno())  # make the lines durable before the series starts

def _capture_frame_series(picam2, params, width, height, png_folder, raw_folder, metadata_csv):
    """Capture ``frame_count`` frames paced at ``framerate`` and return the PNG futures.

    For frame ``i`` this writes:
      * ``png/frame_{i:04d}.png``: the Y plane, saved on a background thread;
      * ``raw_y10/frame_{i:04d}.raw``: the raw buffer;
      * one CSV row of request metadata, with missing keys recorded as 0.

    The executor is shut down before returning, so every returned future is
    already done.
    """
    metadata_keys = (
        "FrameWallClock", "SensorTimestamp", "FrameDuration", "ExposureTime",
        "AnalogueGain", "DigitalGain", "Lux", "AeState", "SensorBlackLevels",
        "FocusFoM", "ScalerCrop",
    )
    interval = 1.0 / params["framerate"]
    # Pace with the monotonic clock so wall-clock jumps cannot skew the schedule.
    next_frame_time = time.monotonic()
    png_futures = []

    with open(metadata_csv, "w", newline="") as csvfile, ThreadPoolExecutor(max_workers=2) as executor:
        writer = csv.writer(csvfile)
        writer.writerow([
            "Frame", "FrameWallClock", "SensorTimestamp_us", "FrameDuration", "ExposureTime",
            "AnalogueGain", "DigitalGain", "Lux", "AeState", "SensorBlackLevels",
            "FocusFoM", "ScalerCrop"
        ])

        for i in range(params["frame_count"]):
            yuv_buffer, raw_bytes, meta = _grab_frame(picam2)

            y_plane = extract_y_plane_from_yuv420(yuv_buffer, width, height)
            png_path = os.path.join(png_folder, f"frame_{i:04d}.png")
            raw_path = os.path.join(raw_folder, f"frame_{i:04d}.raw")

            png_futures.append(executor.submit(save_image, png_path, y_plane))
            _write_bytes(raw_path, raw_bytes)

            writer.writerow([i] + [meta.get(key, 0) for key in metadata_keys])

            next_frame_time += interval
            sleep_time = next_frame_time - time.monotonic()
            if sleep_time > 0:
                time.sleep(sleep_time)

    return png_futures

def _convert_raw_folder_to_tiff(raw_folder, tiff_folder, width, height):
    """Unpack every ``*.raw`` file in ``raw_folder`` to a uint16 TIFF in ``tiff_folder``."""
    raw_files = sorted(f for f in os.listdir(raw_folder) if f.endswith(".raw"))
    for raw_file in raw_files:
        with open(os.path.join(raw_folder, raw_file), "rb") as f:
            raw16 = unpack_y10(f.read(), width, height)
        tiff_name = os.path.splitext(raw_file)[0] + ".tiff"
        _safe_save_raw_tiff(os.path.join(tiff_folder, tiff_name), raw16)


def encode_to_mp4(params, results_folder):
    """Encode the numbered PNG frames into ``Frames/video.mp4`` with ffmpeg.

    Reads ``Frames/png/frame_0000.png``, ``frame_0001.png``, ... at
    ``params["framerate"]`` frames per second. ``frame_sync.png`` does not
    match the pattern and is skipped. The output is H.264 with the yuv420p
    pixel format.

    The scale filter rounds width and height down to even numbers, because
    libx264 cannot encode yuv420p with odd dimensions. For even sizes such as
    1280x800 it leaves the frames unchanged.

    Returns:
        Path of the written MP4 file.

    Raises:
        subprocess.CalledProcessError: if ffmpeg exits with a non-zero status.
        FileNotFoundError: if the ``ffmpeg`` executable cannot be found.
    """
    frames_dir = os.path.join(results_folder, "Frames")
    png_pattern = os.path.join(frames_dir, "png", "frame_%04d.png")
    mp4_file = os.path.join(frames_dir, "video.mp4")
    ffmpeg_cmd = [
        "ffmpeg", "-y", "-framerate", str(params["framerate"]),
        "-i", png_pattern,
        "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2",
        "-c:v", "libx264", "-pix_fmt", "yuv420p", "-preset", "veryfast", "-crf", "23",
        mp4_file,
    ]
    subprocess.run(ffmpeg_cmd, check=True)
    return mp4_file

def call_roi_capture_script(results_folder, script_path="/path to initial electrode capture"):
    """Run the electrode/ROI capture script as ``python3 <script_path> <results_folder>``.

    Args:
        results_folder: passed to the script as its only argument.
        script_path: path of the ROI capture script. The default is a
            placeholder that must be replaced with the real path for your
            setup.

    Raises:
        subprocess.CalledProcessError: if the script exits with a non-zero status.
    """
    subprocess.run(["python3", script_path, results_folder], check=True)

def call_analysis_script(params, results_folder):
    """Call the ECL analysis script with all necessary parameters.

    The script is launched as::

        python3 <analysis_script> <results_folder> <framerate> <cv_start_pot>
            <cv_end_pot> <cv_n_cycles> <cv_scan_rate> <ca_start_time>
            <electrode_area> <ca_applied_potential> <roi_shape> <num_rois>
            <fixed_radius> <fixed_width> <fixed_height>

    Every value is converted with ``str()``. This is best effort: a missing
    script only prints a warning, and a non-zero exit status is reported on
    stdout instead of raising.
    """
    script = params.get("analysis_script")
    if not (script and os.path.exists(script)):
        print(f"Warning: Analysis script not found at {script}")
        return

    rule = "=" * 60
    print("\n" + rule)
    print("RUNNING ECL ANALYSIS")
    print(rule + "\n")

    # Positional argument order is the contract with the analysis script.
    leading_keys = (
        "framerate", "cv_start_pot", "cv_end_pot", "cv_n_cycles",
        "cv_scan_rate", "ca_start_time", "electrode_area",
        "ca_applied_potential",
    )
    trailing_keys = ("num_rois", "fixed_radius", "fixed_width", "fixed_height")
    cmd = ["python3", script, results_folder]
    cmd += [str(params[key]) for key in leading_keys]
    cmd.append(params["roi_shape"])
    cmd += [str(params[key]) for key in trailing_keys]

    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as err:
        print(f"\nAnalysis failed with error: {err}")
    else:
        print("\nAnalysis completed successfully!")

def main():
    """Run the full experiment: ROI capture, frame capture, video, analysis.

    Results go to a fresh timestamped folder. If any step raises, a short
    failure message is printed and the exception is re-raised.
    """
    results_folder = create_results_folder()
    params = calculate_camera_params(USER_PARAMS)
    separator = "=" * 60

    try:
        call_roi_capture_script(results_folder)              # 1. ROI capture
        capture_gray_and_raw_Frames(params, results_folder)  # 2. frames + potentiostat
        encode_to_mp4(params, results_folder)                # 3. video
        if params.get("run_analysis", True):                 # 4. optional analysis
            call_analysis_script(params, results_folder)

        summary = (
            f"\n{separator}",
            "EXPERIMENT COMPLETE",
            f"Results saved in: {results_folder}",
            f"{separator}\n",
        )
        for line in summary:
            print(line)
    except Exception as exc:
        print(f"Experiment failed: {exc}")
        raise

# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
