using Plots
using JSON
using Statistics
using DelimitedFiles

include("params.jl")

"""
    load_history(filename)

Read a space-delimited numeric history file and return its rows as a vector of
8-tuples of `Float64`.

The first eight columns of every row are kept, in file order. `plot_history`
destructures each tuple as
`(t, T_in, T_out, P_in, Q_wall, E_solid, E_oil, Energy_in)`.

Comment lines are skipped by `readdlm` (`comments=true`). A one-line summary
with the number of loaded rows is printed to standard output.

# Throws
- `ErrorException` if `filename` does not exist.
- `ErrorException` if the file contains data rows with fewer than 8 columns.
"""
function load_history(filename)
    isfile(filename) || error("File $filename not found")

    # Single-space delimiter; comment lines are dropped by readdlm itself.
    data_matrix = readdlm(filename, ' ', Float64, comments=true)

    nrows, ncols = size(data_matrix)

    # Fail early with a descriptive message instead of a BoundsError raised
    # from inside the comprehension below. A file without data rows is accepted
    # and yields an empty history.
    if nrows > 0 && ncols < 8
        error("File $filename has $ncols columns per row, expected at least 8")
    end

    # Convert the matrix back to an array of tuples, one per row.
    history = [ntuple(j -> data_matrix[i, j], Val(8)) for i in 1:nrows]

    println("Loaded $(length(history)) time steps from $filename")
    return history
end

"""
    plot_history(history, ta, tb, datadir, ajo_id)

Plot the samples of `history` with `ta <= t <= tb` as a grid of subplots and
save the figure as a PNG file.

# Arguments
- `history`: iterable of 8-tuples
  `(t, T_in, T_out, P_in, Q_wall, E_solid, E_oil, Energy_in)`. `t` is divided
  by 3600 for the x axis, which is labelled "hours".
- `ta`, `tb`: inclusive limits of the time window, in the same unit as `t`.
- `datadir`: directory the figure is written to.
- `ajo_id`: identifier used as the prefix of the figure file name.

The figure is written to `<datadir>/<ajo_id>_history_<ta>-<tb>.png`, where `ta`
and `tb` are rounded to integers. The return value is whatever `savefig`
returns.

The energy-balance error (percent) is `100 * (Energy_in - (E_oil + E_solid)) /
Energy_in`. It is reported as `0.0` for samples with `t <= 300` and for samples
whose `Energy_in` is zero, where the ratio is undefined.
"""
function plot_history(history, ta, tb, datadir, ajo_id)
    # Keyword arguments shared by every subplot.
    grid_kwargs = (
        grid=true,
        gridwidth=1.0,
        gridcolor=:gray,
        gridalpha=0.6,
        xlabelfontsize=10,
        xlabel="hours",
    )

    # The balance error is not evaluated up to this time (same unit as `t`).
    settle_time = 300.0

    hours = Float64[]
    T_in = Float64[]
    T_out = Float64[]
    P_in = Float64[]
    Q_wall = Float64[]
    E_solid = Float64[]
    E_oil = Float64[]
    Energy_in = Float64[]
    balance_error = Float64[]

    for (t, T_in_i, T_out_i, P_in_i, Q_wall_i, E_solid_i, E_oil_i, Energy_in_i) in history
        # Keep only the samples inside the requested window.
        ta <= t <= tb || continue

        push!(hours, t / 3600.0)
        push!(T_in, T_in_i)
        push!(T_out, T_out_i)
        push!(P_in, P_in_i)
        push!(Q_wall, Q_wall_i)
        push!(E_solid, E_solid_i)
        push!(E_oil, E_oil_i)
        push!(Energy_in, Energy_in_i)

        # Compare against the original time stamp rather than converting the
        # hour value back, and never divide by a zero energy input.
        if t > settle_time && !iszero(Energy_in_i)
            E_stored = E_oil_i + E_solid_i
            push!(balance_error, 100.0 * (Energy_in_i - E_stored) / Energy_in_i)
        else
            push!(balance_error, 0.0)
        end
    end

    T_diff = T_in .- T_out
    E_oil_solid = E_oil .+ E_solid

    # Temperatures and their difference.
    p1 = plot(hours, T_in, label="T_in"; grid_kwargs...)
    plot!(p1, hours, T_out, label="T_out"; grid_kwargs...)
    p2 = plot(hours, T_diff, label="T_in - T_out"; grid_kwargs...)

    # Label: "energy fed in" vs. "heat content of cylinder and oil".
    p3 = plot(hours, Energy_in, label="Sisään syötetty energia"; grid_kwargs...)
    plot!(p3, hours, E_oil_solid, label="sylinterin ja öljyn lämpömäärä"; grid_kwargs...)

    # Label: "heat stored in the cylinder".
    p4 = plot(hours, E_solid, label="sylinteriin varastoitunut lämpö"; grid_kwargs...)

    # Label: "heat stored in the oil".
    p5 = plot(hours, E_oil, label="öljyyn varastoitunut lämpö"; grid_kwargs...)

    p6 = plot(hours, P_in, label="heater output"; grid_kwargs...)
    plot!(p6, hours, Q_wall, label="heat flow from oil to solid"; grid_kwargs...)

    # Label: "error in energy balance %".
    p7 = plot(hours, balance_error, label="Virhe energiataseessa %"; grid_kwargs...)

    # The layout declares a 4x2 grid (eight cells) while seven subplots are
    # passed; the symbols inside @layout are cell labels, not the variables
    # above. NOTE(review): the last cell is presumably left empty - confirm.
    grid_layout = @layout [p1 p2; p3 p4; p5 p6; p7 p8]
    combined_plot = plot(p1, p2, p6, p4, p5, p3, p7,
        layout=grid_layout, size=(1900, 1000)
    )

    fig_file = joinpath(datadir, "$(ajo_id)_history_$(round(Int,ta))-$(round(Int,tb)).png")
    return savefig(combined_plot, fig_file)
end

########################################################

# "aikaikkuna" = time window. Length of the short plotting windows, used in the
# same unit as the history time stamps (plot_history divides those by 3600 to
# get hours, so 3600 corresponds to one hour).
const aikaikkuna = 3600

"""
    plot_all(params, tulosdir, conf_id)

Load `<tulosdir>/<conf_id>_history.dat` and write the standard set of history
figures into `tulosdir`.

Figures produced:
1. the whole run, from `0` to `params.t_end`;
2. the first window, from `0` to `aikaikkuna`;
3. only when `params.phases` has more than one entry: a window from half an
   `aikaikkuna` before to one and a half `aikaikkuna` after
   `params.phases[1][2]`.
   NOTE(review): `params.phases[1][2]` is presumably the time at which the
   first phase ends - confirm against the `Params` definition.

Returns `nothing` when there is at most one phase, otherwise the result of the
last `plot_history` call.
"""
function plot_all(params, tulosdir, conf_id)
    history = load_history(joinpath(tulosdir, "$(conf_id)_history.dat"))

    # Full run and the first time window.
    plot_history(history, 0, params.t_end, tulosdir, conf_id)
    plot_history(history, 0, aikaikkuna, tulosdir, conf_id)

    # The third figure only makes sense when there is a phase change.
    length(params.phases) > 1 || return nothing

    t_switch = params.phases[1][2]
    window_start = t_switch - aikaikkuna/2.0
    window_end = t_switch + 1.5*aikaikkuna
    return plot_history(history, window_start, window_end, tulosdir, conf_id)
end

# Run main function if this is the main script
#####################################################################
# CLI
if abspath(PROGRAM_FILE) == @__FILE__
    # Usage: julia <this file> [run_id] [conf_id]
    # Both arguments are optional; the defaults are "ajo" and "x".

    # First argument selects the result directory tulokset/<run_id>.
    run_id = isempty(ARGS) ? "ajo" : ARGS[1]
    tulosdir = joinpath("tulokset", run_id)
    ispath(tulosdir) || mkpath(tulosdir)

    # Second argument selects both the configuration file and the prefix of
    # the history file that plot_all reads.
    conf_id = length(ARGS) < 2 ? "x" : ARGS[2]
    config_file = string("config_", conf_id, ".yaml")

    # Params is not defined in this file; presumably it comes from params.jl.
    params = Params(config_file)
    plot_all(params, tulosdir, conf_id)
end

#=
    # display(combined_plot)

      # If running as a script, keep window open
    
      if isinteractive()
        combined_plot
    else
        gui(combined_plot)  # Open in GUI window
        println("Press Enter to close the plot window...")
        readline()
    end
    
p1 = plot(time, T_in, label="T_in", 
          grid=true,
          gridwidth=2,           # Thicker grid lines
          gridcolor=:gray,       # Grid color
          gridalpha=0.8)         # Grid opacity (0-1)

p1 = plot(time, T_in, label="T_in", 
          grid=true,
          xticks=10)  # Approximately 10 ticks on x-axis

# Or set exact tick positions
p1 = plot(time, T_in, label="T_in", 
          grid=true,
          xticks=0:1000:maximum(t))  # Every 1000 units

p1 = plot(time, T_in, label="T_in",
          grid=true,
          gridwidth=1,
          gridcolor=:lightgray,
          gridalpha=0.7,
          minorgrid=true,      # If supported by backend
          minorgridwidth=0.5,
          minorgridcolor=:lightgray,
          minorgridalpha=0.3)

# Enhanced plotting with better grids
    grid_kwargs = (
        grid=true,
        gridwidth=1.5,
        gridcolor=:gray,
        gridalpha=0.6,
        xticks=8,  # More x-axis divisions
        yticks=6   # More y-axis divisions
    )
    
    p1 = plot(time, T_in, label="T_in"; grid_kwargs...)
    p2 = plot(time, T_diff, label="T_in - T_out"; grid_kwargs...)
    p3 = plot(time, Energy_in, label="Energy_in"; grid_kwargs...)
    p4 = plot(time, Q_solid, label="Q_solid"; grid_kwargs...)
    p5 = plot(time, P, label="heater power"; grid_kwargs...)
    p6 = plot(time, Q_pipe, label="Q_pipe"; grid_kwargs...)
    p7 = plot(time, Q_err, label="Q_err"; grid_kwargs...)
    p8 = plot(time, T_out, label="T_out"; grid_kwargs...)          

=#