#pragma once

// This file interleaves code for several modes:
// sampling_mode single: traces one sample per pixel
// sampling_mode multi: retraces all samples in a pixel
// sampling_mode adaptive: retraces pixels according to a heuristic

// shading_mode single: shades one sample per pixel
// shading_mode multi: shades all samples in a pixel
// shading_mode adaptive: shades pixels according to a heuristic

// Packs a 3-component cache index into a single uint.
// Bit layout: x -> bits [0,14), y -> bits [14,28), z -> bits [28,32).
// Bits of a component that do not fit its field are silently discarded.
uint pack_cache_idx(uint3 val)
{
    const uint field_x = val.x & 0x3FFFu;          // 14 bits
    const uint field_y = (val.y & 0x3FFFu) << 14u; // 14 bits
    const uint field_z = (val.z & 0xFu) << 28u;    // 4 bits
    return field_x | field_y | field_z;
}

// Inverse of pack_cache_idx: splits a packed uint back into its 14/14/4-bit
// x/y/z fields (x in the low bits, z in the top 4 bits).
uint3 unpack_cache_idx(uint val)
{
    return uint3(
        val & 0x3FFFu,
        (val >> 14u) & 0x3FFFu,
        (val >> 28u) & 0xFu);
}

// Primary-ray hit record stored per cache sample (see encode_first_hit /
// decode_first_hit for its packed float4 representation).
struct first_hit_t
{
    // Instance that was hit; UINT_MAX is used throughout this file as the
    // "no valid hit" sentinel.
    uint instance_id;
    // Primitive index within the instance (passed to get_local_geometry).
    uint prim_id;
    // Material id; only the low 16 bits survive encode_first_hit.
    uint mat_id;
    // Presumably the barycentric coordinates of the hit on the primitive
    // (passed to get_local_geometry) — TODO confirm.
    float2 bary;
};

// One cache slot: a sample's first-hit record plus its cached radiance.
struct sample_entry_t
{
    first_hit_t fh;
    cached_radiance_t radiance;

    // Resets this slot to the empty state: instance_id is set to the UINT_MAX
    // sentinel, the remaining first-hit fields are zeroed, and radiance takes
    // the value returned by init_cached_radiance().
    void init()
    {
        fh.bary = float2(0.0f, 0.0f);
        fh.mat_id = 0;
        fh.prim_id = 0;
        fh.instance_id = UINT_MAX;
        radiance = init_cached_radiance();
    }
};

// Returns a freshly reset cache slot (equivalent to calling init() on a new entry).
sample_entry_t empty_entry()
{
    sample_entry_t blank;
    blank.init();
    return blank;
}

// Dimensions of a sample entry: the diffuse layout when radiance is cached
// (sc_has_radiance), otherwise the first-hit-only layout.
uint3 sample_entry_dims()
{
    return sc_has_radiance ? sample_entry_dims_diffuse : sample_entry_dims_fh;
}

// Quantizes each component of a float2 to a 24-bit UNORM value.
// Inputs are clamped to [0, 1] first, then scaled by 2^24 - 1 and rounded.
uint2 encodeFloat2ToUint2(float2 input) {
    const float unorm24_scale = float(0xFFFFFF); // 2^24 - 1, exactly representable in float32
    return uint2(round(clamp(input, 0.0, 1.0) * unorm24_scale));
}

// Converts two 24-bit UNORM integers back to floats in [0, 1]
// (inverse of encodeFloat2ToUint2, up to quantization).
float2 decodeUint2ToFloat2(uint2 encoded) {
    const float unorm24_scale = float(0xFFFFFF); // 2^24 - 1
    return float2(encoded) / unorm24_scale;
}


// Packs a first_hit_t into a float4 (bit patterns stored via asfloat).
// Valid hit layout:
//   x: bary.x as 24-bit UNORM (bits 0-23) | mat_id bits 0-7  (bits 24-31)
//   y: bary.y as 24-bit UNORM (bits 0-23) | mat_id bits 8-15 (bits 24-31)
//   z: prim_id (32 bit)
//   w: instance_id (32 bit)
// When instance_id == UINT_MAX, bary is stored unquantized in xy and mat_id is dropped.
float4 encode_first_hit(first_hit_t hit)
{
    if (hit.instance_id == UINT_MAX)
        return float4(hit.bary, asfloat(hit.prim_id), asfloat(hit.instance_id));

    const uint2 quantized = encodeFloat2ToUint2(hit.bary);
    const uint mat_low_byte  = (hit.mat_id & 0xFFu) << 24u;   // mat_id[0:8)  -> bits 24-31
    const uint mat_high_byte = (hit.mat_id & 0xFF00u) << 16u; // mat_id[8:16) -> bits 24-31
    return float4(
        asfloat((quantized.x & 0xFFFFFFu) | mat_low_byte),
        asfloat((quantized.y & 0xFFFFFFu) | mat_high_byte),
        asfloat(hit.prim_id),
        asfloat(hit.instance_id));
}

// Unpacks a float4 produced by encode_first_hit back into a first_hit_t.
// For the UINT_MAX instance sentinel, bary is read raw from xy and mat_id is
// reported as UINT_MAX; otherwise bary is dequantized from the low 24 bits of
// x/y and mat_id is reassembled from their top bytes (x: low byte, y: high byte).
first_hit_t decode_first_hit(float4 val)
{
    const uint4 bits = asuint(val);

    first_hit_t hit;
    hit.instance_id = bits.w;
    hit.prim_id = bits.z;

    if (hit.instance_id == UINT_MAX)
    {
        hit.bary = val.xy;
        hit.mat_id = UINT_MAX;
        return hit;
    }

    hit.bary = decodeUint2ToFloat2(bits.xy & 0xFFFFFFu);
    hit.mat_id = (bits.x >> 24u) | ((bits.y >> 16u) & 0xFF00u);
    return hit;
}

// Decides which sample slots of pixel gId should be (re)traced this frame and
// returns them as a bitmask over sample slots. The mask contains:
//  - every slot reported as uninitialized by calculate_masks,
//  - for one pixel of every 4x4 block (chosen by sample_idx), one slot picked
//    in round-robin order, so animations and geometry coming from the near
//    plane are eventually caught,
//  - every visible sample whose reconstructed linear depth lies behind a
//    threshold built from the cross-shaped neighbourhood of depthmax.
// mask_val is only read here. depthmin is currently unused.
uint retrace_decision(
    uint3 gId,
    uint sample_idx,
    uint3 res,
    inout uint mask_val,
    CameraState cam,
    RWTexture2DArray<int> depthmin,
    RWTexture2DArray<int64_t> depthmax,
    RWTexture2DArray<float4> cache_fh
)
{
    static const uint2 trace_block = uint2(4,4);

    uint trace_mask = calculate_masks(mask_val, num_sample_slots()).uninitialized_samples_mask;
    const uint occupant_samples = calculate_masks(mask_val, num_sample_slots()).occupancy_mask;
    const uint num_samples = countbits(occupant_samples);

    // mark a subset of samples always so that we can catch animations and geometry coming from the near plane
    const uint trace_block_size = trace_block.x * trace_block.y;
    const uint2 blockpixel = gId.xy % trace_block;
    const uint lin_blockpixel = trace_block.x * blockpixel.y + blockpixel.x;
    // with no occupied slots the modulo below would divide by zero
    if (num_samples > 0 && sample_idx % trace_block_size == lin_blockpixel)
    {
        const uint round_robin_sample = (sample_idx / trace_block_size) % num_samples;
        trace_mask |= (1u << round_robin_sample);
    }

    // cross-shaped neighbourhood; kernel_size must equal the number of offsets
    // (it was previously 9 and read past the end of this 5-entry array)
    static const uint kernel_size = 5;
    static const int3 kernel_offsets[5] = {
        {0, 0, 0}, {-1, 0, 0}, {1, 0, 0}, {0, -1, 0}, {0, 1, 0}
    };

    // each depthmax texel packs two 32-bit floats: the upper half is turned into
    // a distance via cam.proj[2][3] / value, the lower half is a slope
    float min_dist = FLT_MAX;
    float slope = 0;
    for (uint k = 0; k < kernel_size; k++)
    {
        const int3 p = int3(gId) + kernel_offsets[k];
        if (any(p.xy < 0) || any(p.xy >= int2(res.xy)))
            continue;

        const int64_t depthslope = depthmax[p];
        min_dist = min(min_dist, cam.proj[2][3] / asfloat(uint(depthslope >> 32u)));
        slope = max(slope, asfloat(uint(depthslope)));
    }

    // loop-invariant: visible samples further away than this are retraced
    const float threshold = 1.01 * (min_dist + 2.1 * slope);

    uint visible_samples = get_visible_samples_mask(mask_val);
    // stop once the visible mask is exhausted: firstbitlow(0) returns ~0u, which
    // would otherwise produce a bogus cache index and set bit 31 of trace_mask
    for (uint n = 0; n < num_samples && visible_samples != 0; n++)
    {
        const uint z = firstbitlow(visible_samples);
        visible_samples &= ~(1u << z);

        const uint3 cache_idx = decode_cache_idx(gId, z);
        const first_hit_t fh = decode_first_hit(cache_fh[cache_idx]);
        const LocalGeometry lg = get_local_geometry(fh.instance_id, fh.prim_id, fh.bary);

        const float lin_depth = (cam.proj[2][3]) / cam.worldToScreenUv(lg.pos).z; // linearize depth

        if (lin_depth > threshold)
            trace_mask |= (1u << z);
    }

    return trace_mask;
}

// updates mask_val to create/remove samples (or set/unset visibility when adaptive sampling)
// and returns a bitmask of samples that shall be traced.
// Updates mask_val to create/remove samples (or set/unset visibility when
// sampling adaptively) so the pixel's sample density moves towards
// target_density, and returns a bitmask of sample slots that shall be traced.
//  - A pixel with no visible samples always gets one new sample.
//  - Otherwise, if |target - measured density| exceeds density_tolerance, the
//    difference (rounded away from zero) is added/removed.
//  - The change is clamped so at least one active sample remains and at most
//    num_slots samples are visible; single sampling adds at most one per call.
// When adaptive shading is enabled, created samples are counted in
// shading_statistics[frame_idx % 3].bins[0]. dbg is currently unused.
uint density_analysis(
    uint3 gId,
    uint frame_idx,
    uint3 img_res,
    CameraState cam,
    uint num_slots,
    float target_density,
    float density_tolerance,
    uint sampling_mode,
    uint shading_mode,
    RWTexture2DArray<uint> cache_mask,
    RWTexture2DArray<float4> cache_data,
    RWTexture2DArray<int> depthmin,
    RWTexture2DArray<int64_t> depthmax,
    RWStructuredBuffer<shading_statistics_t> shading_statistics,
    RWTexture2DArray<float4> dbg,
    inout uint mask_val
)
{
    // when adaptively sampling, a heuristic can mark samples as invisible to
    // trigger the creation of more samples
    uint trace_mask = 0;
    if (sampling_mode == ADAPTIVE_SAMPLING)
        trace_mask = retrace_decision(gId, frame_idx, img_res, mask_val, cam, depthmin, depthmax, cache_data);

    const float density = measure_density(cache_mask, gId, num_slots);
    const uint num_visible_samples = count_visible_samples(mask_val, gId, num_slots);
    const float sample_tolerance = density_tolerance; // alternative tried: max(1.0, 0.5 * target_density)
    const float delta_D = target_density - density;

    // desired change in sample count (0 while within tolerance)
    int delta_N = 0;
    if (num_visible_samples == 0)
        delta_N = 1;
    else if (abs(delta_D) > sample_tolerance)
        delta_N = int(sign(delta_D) * ceil(abs(delta_D)));

    // keep at least one active sample, never exceed the slot count
    const int min_N = 1 - int(countbits(get_active_samples(mask_val)));
    const int max_N = int(num_slots) - int(num_visible_samples);
    delta_N = clamp(delta_N, min_N, max_N);

    if (sampling_mode == SINGLE_SAMPLING)
        delta_N = min(delta_N, 1);

    if (delta_N > 0)
    {
        for (int j = 0; j < delta_N; j++)
            create_sample(mask_val, gId);

        if (shading_mode == ADAPTIVE_SHADING)
            InterlockedAdd(shading_statistics[frame_idx % 3].bins[0], delta_N);
    }
    else if (delta_N < 0)
    {
        for (int j = 0; j < -delta_N; j++)
            remove_sample(mask_val, gId);
    }

    // choosing which rays to trace:
    // single: choose a new or an arbitrary sample (round-robin) to trace
    // multi: trace all samples, new and old
    // adaptive: trace probably occluded and new ones always, stable samples in extended round-robin
    if (sampling_mode == SINGLE_SAMPLING)
        trace_mask = (1u << choose_target_sample(mask_val, num_slots));
    else if (sampling_mode == MULTI_SAMPLING)
        trace_mask = get_active_samples(mask_val);
    else
        trace_mask |= calculate_masks(mask_val, num_slots).uninitialized_samples_mask;

    return trace_mask;
}

// Returns a bitmask of sample slots of pixel img_idx that should be shaded this frame.
//  - One occupied slot (round-robin over frames) is always shaded for one pixel
//    of every shade_block, giving a baseline shading rate.
//  - An importance value is estimated from a 2x2 neighbourhood (this pixel and
//    its +x, +y, +xy neighbours, clipped to img_res). It is then placed in the
//    log-spaced importance histogram stored in shading_statistics[frame_idx % 3].
//  - Bins whose cumulative count stays below target_sum (= shading_rate * pixel
//    count) shade all visible samples. The bin straddling target_sum shades each
//    visible sample with probability target/bin, decided by an 8x8 Bayer dither.
// cache_res, target_density and dbg_tex are currently unused.
uint reshade_decision(
    uint3 img_idx,
    uint frame_idx,
    uint3 img_res,
    uint3 cache_res,
    uint mask_val,
    uint num_slots,
    float target_density,
    float shading_rate,
    RWTexture2DArray<float4> cache_fh,
    RWTexture2DArray<float4> cache_diffuse,
    RWStructuredBuffer<shading_statistics_t> shading_statistics,
    RWTexture2DArray<float4> dbg_tex
)
{
    uint shade_mask = 0u;
    const uint occupant_samples = calculate_masks(mask_val, num_sample_slots()).occupancy_mask;
    const uint num_samples = countbits(occupant_samples);

    // mark one sample in every shade_block as active to get a baseline shading rate
    const uint shade_block_size = shade_block.x * shade_block.y;
    const uint2 blockpixel = img_idx.xy % shade_block;
    const uint lin_blockpixel = shade_block.x * blockpixel.y + blockpixel.x;
    // with no occupied slots the modulo below would divide by zero
    if (num_samples > 0 && frame_idx % shade_block_size == lin_blockpixel)
    {
        const uint round_robin_sample = (frame_idx / shade_block_size) % num_samples;
        shade_mask |= (1u << round_robin_sample);
    }

    const uint visible_samples_mask = calculate_masks(mask_val, num_slots).visible_samples_mask;
    const uint num_visible_samples = countbits(visible_samples_mask);

    // we cheat a bit by only considering the importance of the neighbourhood and
    // assuming similar values for all samples in the pixel
    float importance = 0;
    if (num_visible_samples > 0) // avoids % 0 and / 0 below
    {
        static const uint3 kernel_offsets[4] = {
            uint3(0,0,0), uint3(1,0,0), uint3(0, 1, 0), uint3(1,1,0)
        };
        const uint sample_id = frame_idx % num_visible_samples;
        for (uint k = 0; k < 4; k++)
        {
            const uint3 p = img_idx + kernel_offsets[k];
            if (any(p >= img_res))
                continue;
            // read the neighbour p (previously img_idx was read for all four taps,
            // so the bounds-checked neighbour was never sampled)
            const uint3 cache_idx = decode_cache_idx(p, sample_id);
            const first_hit_t fh = decode_first_hit(cache_fh[cache_idx]);
            const float4 diffuse = cache_diffuse[cache_idx];
            const float sample_importance = fh.instance_id != UINT_MAX ? 1.0f / (1 + diffuse.a) : importance_min_val;
            importance += sample_importance / float(num_visible_samples);
        }
    }

    const uint stats_idx = frame_idx % 3;
    shading_statistics_t stats = shading_statistics[stats_idx];
    // histogram range is read before histogram_postsum replaces stats
    const float lowest_bin = asfloat(stats.lowest_bin);
    const float highest_bin = asfloat(stats.highest_bin);
    const float log_lowest_bin = log2(lowest_bin);
    const float log_highest_bin = log2(highest_bin);
    const uint target_sum = uint(shading_rate * (img_res.x * img_res.y * img_res.z));

    stats = histogram_postsum(frame_idx, img_res, stats);

    if (importance > importance_min_val)
    {
        // bin 0 holds the highest importance, bin num_bins-1 the lowest
        uint bin_idx = 0;
        if (importance <= lowest_bin)
            bin_idx = num_bins-1;
        else if (importance >= highest_bin)
            bin_idx = 0;
        else
            // compute correct bin via remap (invlerp followed by lerp)
            bin_idx = clamp(remapf<uint>(log_highest_bin, log_lowest_bin, 0, num_bins, log2(importance)), 0, num_bins-1);

        const uint presum = bin_idx > 0 ? stats.bins[bin_idx-1] : uint(1.0 / shade_block_size * (img_res.x * img_res.y * img_res.z));
        const uint postsum = stats.bins[bin_idx];

        // per-sample decision
        if (postsum < target_sum)
        {
            shade_mask |= visible_samples_mask;
        }
        else if (presum < target_sum)
        {
            const uint bin = postsum - presum;
            const uint target = target_sum - presum;
            const float prob = bin > 0 ? clamp(float(target) / float(bin), 0.0f, 1.0f) : 1.0f;

            // choose "randomly" according to bayer pattern
            static const float bayer8x8[64] = {
                0.000, 0.500, 0.125, 0.625, 0.031, 0.531, 0.156, 0.656,
                0.750, 0.250, 0.875, 0.375, 0.781, 0.281, 0.906, 0.406,
                0.188, 0.688, 0.062, 0.562, 0.219, 0.719, 0.094, 0.594,
                0.938, 0.438, 0.812, 0.312, 0.969, 0.469, 0.844, 0.344,
                0.047, 0.547, 0.172, 0.672, 0.016, 0.516, 0.141, 0.641,
                0.797, 0.297, 0.922, 0.422, 0.766, 0.266, 0.891, 0.391,
                0.234, 0.734, 0.109, 0.609, 0.203, 0.703, 0.078, 0.578,
                0.984, 0.484, 0.859, 0.359, 0.953, 0.453, 0.828, 0.328
            };
            // 8x8 tile (an 8x16 tile indexes identically once taken modulo 64)
            const uint2 block = uint2(8, 8);
            const uint2 pixel = img_idx.xy % block;
            const uint dither_base = pixel.y * block.x + pixel.x;

            uint iter_mask = visible_samples_mask;
            while (iter_mask != 0)
            {
                const uint target_sample = firstbitlow(iter_mask);
                iter_mask &= ~(1u << target_sample);

                const uint dither_idx = (dither_base + frame_idx * 37 + target_sample * 11) % 64;
                if (bayer8x8[dither_idx] < prob)
                    shade_mask |= (1u << target_sample);
            }
        }
    }

    return shade_mask;
}