using Plots
using JSON
using Statistics
using DelimitedFiles

include("params.jl")

"""
    load_history(filename)

Read a space-delimited history file and return its rows as a vector of
8-element tuples (one tuple per row, built from columns 1 through 8).

Lines are parsed as `Float64`; comment handling is delegated to `readdlm`
(`comments=true`). A summary line is printed to standard output.

# Throws
- `ErrorException` if `filename` does not exist.
"""
function load_history(filename)
    isfile(filename) || error("File $filename not found")

    # Parse the whole file into a Float64 matrix, ignoring comment lines.
    data_matrix = readdlm(filename, ' ', Float64, comments=true)

    # One tuple per matrix row, taking the first eight columns in order.
    history = [ntuple(col -> data_matrix[row, col], 8) for row in axes(data_matrix, 1)]

    println("Loaded $(length(history)) time steps from $filename")
    return history
end

#=
    grid_kwargs = (
        grid=true,
        gridwidth=1.0,
        gridcolor=:gray,
        gridalpha=0.6,
        # legend=:outerbottom,
        xlabelfontsize=10,
        xlabel="hours"

        #xticks=10    # 0:10:maximum(t),
        # yticks=7,   # More y-axis divisions
        # margin=3Plots.mm,
        # left_margin=5Plots.mm,
        # bottom_margin=5Plots.mm 
    )
    #=
    minorgrid=true,      # If supported by backend
    minorgridwidth=1.0,
    minorgridcolor=:gray,
    minorgridalpha=0.6
    )   
    =#

=#

########################################################

# --- helpers ---------------------------------------------------------------

# Names of the eight columns of a history record, in column order.
# `unpack_history` pairs the i-th name with the i-th tuple element.
const FIELD_NAMES = (:t, :T_in, :T_out, :P_in, :Q_wall, :E_solid, :E_oil, :Energy_in)

"""
    unpack_history(history)

Convert a vector of history tuples into a `Dict{Symbol,Vector{Float64}}` with
one column vector per entry of `FIELD_NAMES`.

The i-th element of every tuple is stored under the i-th name in `FIELD_NAMES`
(8 names), so each tuple must have at least that many elements.
"""
function unpack_history(history)
    # Build each column in one pass over the history instead of pushing
    # element by element.
    return Dict(
        name => Float64[record[idx] for record in history]
        for (idx, name) in enumerate(FIELD_NAMES)
    )
end

"""
    interp1_linear(t, y, tq; extrap=:flat)

Piecewise-linear interpolation of the samples `(t, y)` at the query points `tq`.

# Arguments
- `t`: sample locations, sorted ascending and non-empty.
- `y`: sample values, same length as `t`.
- `tq`: query locations.

# Keywords
- `extrap`: behaviour for queries strictly outside `[first(t), last(t)]`.
  `:nan` yields `NaN`; any other value (default `:flat`) repeats the nearest
  end value.

# Returns
A `Float64` array shaped like `tq`. Queries that coincide exactly with the
first or last sample return that sample's value for every `extrap` mode.

# Throws
- `DimensionMismatch` if `t` and `y` differ in length.
- `ArgumentError` if `t` is empty or not sorted ascending.
"""
function interp1_linear(t::AbstractVector, y::AbstractVector, tq::AbstractVector; extrap=:flat)
    # Explicit checks instead of `@assert`, which may be compiled out.
    length(t) == length(y) || throw(DimensionMismatch("t and y must have same length"))
    isempty(t) && throw(ArgumentError("t must not be empty"))
    issorted(t) || throw(ArgumentError("t must be sorted ascending"))

    yq = similar(tq, Float64)
    t1, tN = first(t), last(t)
    y1, yN = first(y), last(y)
    below = extrap === :nan ? NaN : y1
    above = extrap === :nan ? NaN : yN

    for (k, tk) in pairs(tq)
        if tk < t1
            yq[k] = below
        elseif tk > tN
            yq[k] = above
        elseif tk == t1
            # Endpoint hits are inside the data range, never extrapolation.
            yq[k] = y1
        elseif tk == tN
            # Handled separately so that t[i+1] below never runs past the end.
            yq[k] = yN
        else
            i = searchsortedlast(t, tk)     # t[i] ≤ tk < t[i+1]
            t_lo, t_hi = t[i], t[i+1]
            y_lo, y_hi = y[i], y[i+1]
            α = (tk - t_lo) / (t_hi - t_lo)
            yq[k] = (1 - α) * y_lo + α * y_hi
        end
    end
    return yq
end

"""
    resample_for_compare(history_A, history_B; fields=FIELD_NAMES)

Resample two run histories onto one shared, uniformly spaced time grid.

The grid starts at `0.0`, uses the smallest time increment found in either run
as its step, and ends at the earlier of the two final times. Each requested
field of both runs is linearly interpolated onto that grid (flat beyond the
ends).

# Returns
`(tq_h, outA, outB)` where `tq_h` is the grid divided by 3600 (hours, given
that the `:t` column is in seconds) and `outA` / `outB` map each field name to
its resampled `Vector{Float64}`.
"""
function resample_for_compare(history_A, history_B; fields=FIELD_NAMES)
    A = unpack_history(history_A)
    B = unpack_history(history_B)

    tA = A[:t]
    tB = B[:t]
    @assert issorted(tA) && issorted(tB)

    # Finest step of either run, and the time span both runs cover.
    dt = min(minimum(diff(tA)), minimum(diff(tB)))
    Tf = min(last(tA), last(tB))
    tq = collect(0.0:dt:Tf)

    outA = Dict{Symbol,Vector{Float64}}(
        f => interp1_linear(tA, A[f], tq; extrap=:flat) for f in fields
    )
    outB = Dict{Symbol,Vector{Float64}}(
        f => interp1_linear(tB, B[f], tq; extrap=:flat) for f in fields
    )

    # Grid is returned in hours for use as the plot x-axis.
    return tq ./ 3600.0, outA, outB
end

# Shared keyword arguments for plot styling (grid appearance and axis-label
# fonts). Tick counts, legend placement and margins were tried earlier and are
# intentionally not set here.
const grid_kwargs = (
    grid=true, gridwidth=1.0, gridcolor=:gray, gridalpha=0.6,
    xlabelfontsize=8, xlabel="hours",
    ylabelfontsize=8,
)

"""
    v_percent(skip, AA, BB)

Return the element-wise percentage difference `100 * (A - B) / A` between `AA`
and `BB`, starting at index `skip` of both inputs.

A pair contributes `0.0` instead when `abs(A)` does not exceed
`max(0.01 * abs(A), 0.01 * abs(B), 1.0)`, i.e. when `A` is too small to be a
meaningful denominator.

# Returns
A `Vector{Float64}` with one entry per pair of `AA[skip:end]` and
`BB[skip:end]`.
"""
function v_percent(skip, AA, BB)
    # Both branches yield Float64 so the result vector has a concrete eltype.
    function percent(A, B)
        tol = max(0.01 * abs(A), 0.01 * abs(B), 1.0)
        return abs(A) > tol ? 100 * (A - B) / A : 0.0
    end
    return Float64[percent(A, B) for (A, B) in zip(AA[skip:end], BB[skip:end])]
end

"""
    main()

Compare two simulation runs and plot their percentage differences.

Expects three command-line arguments in `ARGS`:
1. results directory containing `<run>_history.dat` files,
2. identifier of run A (also used to pick `config_<runA>.yaml`),
3. identifier of run B.

Both histories are resampled onto a common time grid, the first 200 seconds
are skipped, and seven difference plots are combined into one figure that is
saved as `<runA>-<runB>.png` in the results directory and shown in a GUI
window until Enter is pressed.

# Throws
- `ArgumentError` if fewer than three arguments are given.
- `ErrorException` if the resampled grid has fewer than two points.
"""
function main()

    println("\n ########################################### \n")

    # All three arguments are required; fail early with a readable message
    # instead of a BoundsError on ARGS.
    if length(ARGS) < 3
        throw(ArgumentError(
            "expected 3 arguments: <results_dir> <runA_id> <runB_id>, " *
            "got $(length(ARGS))"
        ))
    end
    tulosdir, runA_id, runB_id = ARGS[1], ARGS[2], ARGS[3]

    # The configuration of run A is loaded.
    # NOTE(review): `params` is not used below; the call is kept because any
    # side effects of `Params` are not visible from this file.
    config_file = "config_$runA_id.yaml"
    params = Params(config_file)

    # Load both histories
    history_A = load_history(joinpath(tulosdir, "$(runA_id)_history.dat"))
    history_B = load_history(joinpath(tulosdir, "$(runB_id)_history.dat"))

    # Resample every field of both runs onto a shared time grid (hours)
    t_h_all, A, B = resample_for_compare(history_A, history_B; fields=FIELD_NAMES)
    length(t_h_all) >= 2 ||
        error("Need at least two resampled time steps to compare runs")

    # Skip the first 200 seconds: at the start the percentages fluctuate,
    # because small values change a lot at slightly different times in the
    # two runs. The interpolation in resample_for_compare helps with this.
    dt = t_h_all[2] - t_h_all[1]
    # Clamp to 1 so a grid step longer than 200 s does not produce index 0.
    skip = max(1, floor(Int, 200.0 / 3600.0 / dt))
    println("dt: $dt  skip: $skip")

    t_h = t_h_all[skip:end]
    E_oil_solid_A = A[:E_oil] .+ A[:E_solid]
    E_oil_solid_B = B[:E_oil] .+ B[:E_solid]

    # NOTE: an earlier, disabled experiment zeroed the P_in difference around
    # the first phase change (params.phases[1][2]); it is not applied here.

    # One percentage-difference panel for a pair of series.
    diff_plot(label, a, b) = plot(
        t_h, v_percent(skip, a, b);
        label=label, xlabel="time [h]", ylabel="ero %",
    )

    p1 = diff_plot("T_in", A[:T_in], B[:T_in])
    p2 = diff_plot("T_out", A[:T_out], B[:T_out])
    p3 = diff_plot("P_in", A[:P_in], B[:P_in])
    p4 = diff_plot("sylinterin lämpömäärä", A[:E_solid], B[:E_solid])
    p5 = diff_plot("öljyn lämpömäärä", A[:E_oil], B[:E_oil])
    p6 = diff_plot("energia sisään", A[:Energy_in], B[:Energy_in])
    p7 = diff_plot("oil + solid lämpömäärä", E_oil_solid_A, E_oil_solid_B)

    # Names inside @layout are only placeholders for grid cells. The panels
    # are placed in the order passed to `plot`, so Energy_in (p6) is shown
    # third and P_in (p3) sixth.
    layout = @layout [p1 p2; p3 p4; p5 p6; p7 x]
    combined_plot = plot(p1, p2, p6, p4, p5, p3, p7,
        layout=layout, size=(1900, 1000)
    )

    fig_file = joinpath(tulosdir, "$(runA_id)-$(runB_id).png")
    savefig(combined_plot, fig_file)

    gui(combined_plot)  # Open in GUI window
    println("Press Enter to close the plot window...")
    readline()

end

# Entry point: call main() only when this file is executed directly as the
# program, not when it is included from another file.
abspath(PROGRAM_FILE) == @__FILE__ && main()

#=

    # display(combined_plot)

      # If running as a script, keep window open
    
      if isinteractive()
        combined_plot
    else
        gui(combined_plot)  # Open in GUI window
        println("Press Enter to close the plot window...")
        readline()
    end
    
p1 = plot(time, T_in), label="T_in", 
          grid=true,
          gridwidth=2,           # Thicker grid lines
          gridcolor=:gray,       # Grid color
          gridalpha=0.8)         # Grid opacity (0-1)

p1 = plot(time, T_in), label="T_in", 
          grid=true,
          xticks=10)  # Approximately 10 ticks on x-axis

# Or set exact tick positions
p1 = plot(time, T_in), label="T_in", 
          grid=true,
          xticks=0:1000:maximum(t))  # Every 1000 units

p1 = plot(time, T_in), label="T_in",
          grid=true,
          gridwidth=1,
          gridcolor=:lightgray,
          gridalpha=0.7,
          minorgrid=true,      # If supported by backend
          minorgridwidth=0.5,
          minorgridcolor=:lightgray,
          minorgridalpha=0.3)

# Enhanced plotting with better grids
    grid_kwargs = (
        grid=true,
        gridwidth=1.5,
        gridcolor=:gray,
        gridalpha=0.6,
        xticks=8,  # More x-axis divisions
        yticks=6   # More y-axis divisions
    )
    
    p1 = plot(time, T_in), label="T_in"; grid_kwargs...)
    p2 = plot(time, T_diff), label="T_in - T_out"; grid_kwargs...)
    p3 = plot(time, Energy_in), label="Energy_in"; grid_kwargs...)
    p4 = plot(time, Q_solid), label="Q_solid"; grid_kwargs...)
    p5 = plot(time, P), label="heater power"; grid_kwargs...)
    p6 = plot(time, Q_pipe), label="Q_pipe"; grid_kwargs...)
    p7 = plot(time, Q_err), label="Q_err"; grid_kwargs...)
    p8 = plot(time, T_out), label="T_out"; grid_kwargs...)          

=#

#=
    tf = 3600.0
    s = 5.0  # 6.6
    U = 600.0
    sigma = 1.1
    wn = s/tf
    dt = 1.0
    t_end = Int(1.5*tf/dt)
    tx = [Float64(i)*dt for i in 1:t_end]
    x1 = [0.0 for t in tx]
    x2 = [0.0 for t in tx]
    y = [0.0 for t in tx]

    x1[1] = 273.9
    x2[1] = 0.0

    for i = 2:t_end
        y[i] = 1.0 - (1.0 + wn*tx[i])*exp(-wn*tx[i])
        dx1 = x2[i-1]
        dx2 = - wn^2*x1[i-1] - 2*wn*sigma*x2[i-1] + wn^2*U
        # println("$i; $(x1[i]); $(x2[i]); $dx1; $dx2")
        x1[i] = x1[i-1] + dx1*dt
        x2[i] = x2[i-1] + dx2*dt
    end
=#