#pragma once


// this code interleaves the implementations of our sample cache with a reimplementation of interactive stable ray tracing:
// mask_mode 0 => our mask and sample layout
// mask_mode 1 => interactive stable ray tracing layout


// Unpacked view of one packed mask value, as produced by calculate_masks().
// All fields are expressed per slot in the low num_slots bits.
struct MaskData
{
    uint occupancy_mask;             // low num_slots bits of the packed value
    uint visibility_mask;            // next num_slots bits of the packed value, shifted down
    uint visible_samples_mask;       // occupancy_mask & visibility_mask
    uint invisible_samples_mask;     // occupancy_mask & ~visibility_mask
    uint uninitialized_samples_mask; // ~occupancy_mask & visibility_mask (see make_sample)
    uint active_samples_mask;        // occupancy_mask | visibility_mask
};

// Number of sample slots encoded in one mask value for the active mask_mode:
// the layer count in mode 0, screen_M squared otherwise.
uint num_sample_slots()
{
    if (mask_mode == 0)
        return g_lubo.screen_space_num_layers;

    return g_lubo.screen_M * g_lubo.screen_M;
}

// Bit mask with the low num_slots bits set (the occupancy half of a packed mask).
uint occupancy_bits(uint num_slots)
{
    const uint above_slots = ~0u << num_slots;
    return ~above_slots;
}

// Bit mask covering the num_slots bits directly above the occupancy bits
// (the visibility half of a packed mask).
uint visibility_bits(uint num_slots)
{
    const uint low_bits = occupancy_bits(num_slots);
    return low_bits << num_slots;
}

// Splits a packed mask value into its occupancy and visibility halves and
// derives the combined per-slot masks from them.
MaskData calculate_masks(uint mask_val, uint num_slots)
{
    const uint occupied = mask_val & occupancy_bits(num_slots);
    const uint visible = (mask_val & visibility_bits(num_slots)) >> num_slots;

    MaskData result;
    result.occupancy_mask = occupied;
    result.visibility_mask = visible;
    result.visible_samples_mask = occupied & visible;
    result.invisible_samples_mask = occupied & (~visible);
    result.uninitialized_samples_mask = (~occupied) & visible;
    result.active_samples_mask = occupied | visible;
    return result;
}

// Occupancy bits of mask_val, using the slot count of the active mask_mode.
uint get_occupancy_mask(uint mask_val)
{
    const MaskData masks = calculate_masks(mask_val, num_sample_slots());
    return masks.occupancy_mask;
}

// Mask of slots that are occupied or flagged visible (or both),
// using the slot count of the active mask_mode.
uint get_active_samples(uint mask_val)
{
    const MaskData masks = calculate_masks(mask_val, num_sample_slots());
    return masks.active_samples_mask;
}

// Mask of slots that are both occupied and visible,
// using the slot count of the active mask_mode.
uint get_visible_samples_mask(uint mask_val)
{
    const MaskData masks = calculate_masks(mask_val, num_sample_slots());
    return masks.visible_samples_mask;
}

// Number of occupied slots in mask_val. gId is not used.
uint count_samples(uint mask_val, uint3 gId, uint num_slots)
{
    return countbits(calculate_masks(mask_val, num_slots).occupancy_mask);
}

// Number of slots in mask_val that are both occupied and visible. gId is not used.
uint count_visible_samples(uint mask_val, uint3 gId, uint num_slots)
{
    return countbits(calculate_masks(mask_val, num_slots).visible_samples_mask);
}

// True when the occupancy bit of slot sample_idx is set in mask_val.
bool is_sample_occupied(uint mask_val, uint sample_idx)
{
    const uint slot_bit = (mask_val >> sample_idx) & 1u;
    return slot_bit != 0u;
}

// True when the visibility bit of slot sample_idx (stored num_slots bits above
// its occupancy bit) is set in mask_val.
bool is_sample_visible(uint mask_val, uint sample_idx, uint num_slots)
{
    const uint bit_pos = sample_idx + num_slots;
    return ((mask_val >> bit_pos) & 1u) != 0u;
}

// True when slot sample_idx has both its occupancy bit and its visibility bit set.
bool is_sample_valid(uint mask_val, uint sample_idx, uint num_slots)
{
    const uint required_bits = (1u << sample_idx) | (1u << (sample_idx + num_slots));
    // valid when none of the required bits is missing from the mask
    return (~mask_val & required_bits) == 0u;
}

// True when mask_val equals the pattern with every occupancy and every
// visibility bit of all num_slots slots set.
bool is_mask_full(uint mask_val, uint num_slots)
{
    const uint full_pattern = occupancy_bits(num_slots) | visibility_bits(num_slots);
    return mask_val == full_pattern;
}

// Marks slot sample_idx as visible but unoccupied (the encoding of an
// uninitialized sample): clears its occupancy bit and sets its visibility bit.
uint make_sample(uint mask_val, uint sample_idx, uint num_slots)
{
    const uint slot_bit = 1u << sample_idx;
    const uint without_occupancy = mask_val & ~slot_bit;
    return without_occupancy | (slot_bit << num_slots);
}

// Sets both the occupancy bit and the visibility bit of slot sample_idx.
uint enable_sample(uint mask_val, uint sample_idx, uint num_slots)
{
    const uint slot_bits = (1u << sample_idx) | (1u << (sample_idx + num_slots));
    return mask_val | slot_bits;
}

// Clears both the occupancy bit and the visibility bit of slot sample_idx.
uint disable_sample(uint mask_val, uint sample_idx, uint num_slots)
{
    const uint slot_bits = (1u << sample_idx) | (1u << (sample_idx + num_slots));
    return mask_val & ~slot_bits;
}

// Sets the visibility bit of slot sample_idx; the occupancy bit is left as is.
uint set_visible_sample(uint mask_val, uint sample_idx, uint num_slots)
{
    const uint visibility_bit = 1u << (sample_idx + num_slots);
    return mask_val | visibility_bit;
}

// Clears the visibility bit of slot sample_idx; the occupancy bit is left as is.
uint set_invisible_sample(uint mask_val, uint sample_idx, uint num_slots)
{
    const uint visibility_bit = 1u << (sample_idx + num_slots);
    return mask_val & ~visibility_bit;
}

// Weighted average of the occupied-slot count over the 3x3 neighbourhood of gId
// in mask_tex (same array layer). The centre texel has weight 4, the others
// weight 1; texels outside dim3(mask_tex) are skipped and do not contribute
// to the normalisation.
float measure_density(RWTexture2DArray<uint> mask_tex, uint3 gId, uint num_slots)
{
    uint weighted_count = 0;
    float weight_sum = 0;

    // row-major walk over the neighbourhood: dy outer, dx inner
    for (int dy = -1; dy <= 1; dy++)
    {
        for (int dx = -1; dx <= 1; dx++)
        {
            // at the low border the unsigned add wraps around and fails the bounds test
            uint3 texel = gId + int3(dx, dy, 0);
            if (all(texel < dim3(mask_tex)))
            {
                uint occupied = countbits(mask_tex[texel] & occupancy_bits(num_slots));
                float texel_weight = (dx == 0 && dy == 0) ? 4 : 1;
                weighted_count += texel_weight * occupied;
                weight_sum += texel_weight;
            }
        }
    }
    return weighted_count / weight_sum;
}


/*****************************************************************************/

// Row-major slot index of a sub-pixel coordinate in a grid that is
// g_lubo.screen_M wide.
uint linearize_subpixel(uint2 subpixel)
{
    const uint row_start = subpixel.y * g_lubo.screen_M;
    return row_start + subpixel.x;
}

// Inverse of linearize_subpixel: column is slot modulo screen_M, row is slot
// divided by screen_M.
uint2 decode_subpixel(uint slot)
{
    uint2 subpixel;
    subpixel.x = slot % g_lubo.screen_M;
    subpixel.y = slot / g_lubo.screen_M;
    return subpixel;
}

// Returns a 16-bit mask with up to N pseudo-randomly chosen bits set, never
// choosing a bit that is already set in set_mask. At most as many bits are
// chosen as are clear in set_mask. Advances the global `seed` once per
// chosen bit.
uint16_t set_random_bits(uint16_t set_mask, uint N)
{
    const uint free_total = 16 - countbits(uint(set_mask));
    const uint picks = min(N, free_total);

    uint16_t taken = set_mask;   // bits that may no longer be chosen
    uint16_t chosen = 0;
    uint free_left = free_total; // clear bits remaining in `taken`

    for (uint16_t pick = 0; pick < picks; ++pick)
    {
        seed = (seed * 1664525u + 1013904223u);
        // number of clear bits to skip before taking one
        int skip = int(seed % free_left);

        uint16_t pos = 0;
        while (pos < 16)
        {
            const bool is_free = (taken & (1u << pos)) == 0;
            if (is_free)
            {
                if (skip == 0)
                    break;
                skip--;
            }
            pos++;
        }

        const uint16_t picked_bit = uint16_t(1u << pos);
        chosen |= picked_bit;
        taken |= picked_bit;
        free_left--;
    }
    return chosen;
}

// Pseudo-randomly picks one of the bits set in the low num_slots bits of
// mask_val and returns its bit index. Advances the global `seed` once per call.
// Returns num_slots (an out-of-range index) when none of those bits is set;
// callers that may pass an empty candidate mask have to handle that value.
uint choose_random_sample(uint mask_val, uint num_slots)
{
    const uint available_bits = mask_val & occupancy_bits(num_slots);
    const uint num_bits = countbits(available_bits);

    // advance first so the random sequence does not depend on whether a candidate exists
    seed = (seed * 1664525u + 1013904223u);

    // no candidate: bail out instead of computing an integer modulo by zero
    if (num_bits == 0u)
        return num_slots;

    // rank of the candidate to pick, counted from the lowest bit
    int randBit = int(seed % num_bits);
    uint bitIndex = 0;
    for (bitIndex = 0; bitIndex < num_slots; bitIndex++)
    {
        if ((available_bits & (1u << bitIndex)) != 0)
        {
            if (randBit == 0)
                break;
            randBit--;
        }
    }
    return bitIndex;
}

/*****************************************************************************/
/* switch functions for mask_mode */

// Maps a mask texel (gId) and a slot index to a cache index.
// Mode 0: same xy, layer = gId.z * screen_space_num_layers + slot.
// Otherwise: xy is scaled by screen_M and offset by the decoded sub-pixel, layer = gId.z.
uint3 decode_cache_idx(uint3 gId, uint slot)
{
    if (mask_mode != 0)
    {
        const uint2 subpixel = decode_subpixel(slot);
        return uint3(g_lubo.screen_M * gId.xy + subpixel, gId.z);
    }

    return uint3(gId.xy, gId.z * g_lubo.screen_space_num_layers + slot);
}

// Cache index for a screen uv: xy is screen_uv scaled by cache_res.xy.
// Mode 0 places the sample in layer eye * screen_space_num_layers + target_slot;
// otherwise the layer is just eye and target_slot is ignored.
uint3 compute_cache_idx(
    uint3 cache_res,
    float2 screen_uv,
    uint eye,
    uint target_slot)
{
    uint layer = eye;
    if (mask_mode == 0)
        layer = eye * g_lubo.screen_space_num_layers + target_slot;

    return uint3(screen_uv * cache_res.xy, layer);
}

// Tries to claim a slot in the destination mask for a sample that lands at dst_uv.
// Bits are published with InterlockedOr on dst_cache_mask; the value returned by
// the atomic (the mask before the OR) decides whether this call won the slot.
//   cache_res      - cache resolution, used only when mask_mode != 0
//   gId            - only gId.z is used, as the array layer of the mask texel
//   src_visible    - whether the incoming sample is visible; also sets the visibility bit
//   dst_uv         - destination position; xy is scaled to texels, z is range-checked in mode 0
//   dst_cache_mask - mask texture that is updated atomically
//   target_slot    - out: the claimed slot when true is returned; otherwise not meaningful
// Returns true when the caller should write its sample into target_slot.
bool try_reserve_slot(
    uint3 cache_res,
    uint3 gId,
    bool src_visible,
    float3 dst_uv,
    RWTexture2DArray<uint> dst_cache_mask,
    out uint target_slot
)
{
    if (mask_mode == 0)
    {
        // // this code is basically equivalent and closer to the pseudo-code in the paper, but otherwise untested
        // const uint3 img_res = g_lubo.display_res;
        // const uint num_slots = g_lubo.screen_space_num_layers;
        // const uint3 img_destination = uint3(dst_uv.xy * img_res.xy, gId.z);
        // uint dst_mask_val = dst_cache_mask[img_destination];
        // bool replace_sample = false;
        // target_slot = 0;

        // if (any(img_destination >= img_res) || dst_uv.z < 0 || dst_uv.z > 1)
        //     return replace_sample;

        // // try to find a slot until the target is completely full!
        // while (!replace_sample && !is_mask_full(dst_mask_val, num_slots))
        // {
        //     MaskData dst_mask = calculate_masks(dst_mask_val, num_slots);
        //     bool has_empty_slots = dst_mask.occupancy_mask != occupancy_bits(num_slots);
        //     bool has_invisible_slots = (~dst_mask.visible_samples_mask) != 0u;
        //     bool overwrite_unoccupied = false;
        //     if (has_empty_slots)
        //     {
        //         target_slot = firstbitlow(((1u << num_slots) - 1u) & ~dst_mask.occupancy_mask);
        //         overwrite_unoccupied = false;
        //     }
        //     else if (src_visible && has_invisible_slots)
        //     {
        //         target_slot = firstbitlow(((1u << num_slots) - 1u) & ~dst_mask.visible_samples_mask);
        //         overwrite_unoccupied = true;
        //     }
        //     else
        //         break;
        //     const uint o_bit = 1u << (target_slot);
        //     const uint v_bit = 1u << (target_slot + num_slots);
        //     uint s_val = o_bit;
        //     if (src_visible)
        //         s_val |= v_bit;
        //     InterlockedOr(dst_cache_mask[img_destination], s_val, dst_mask_val);
        //     replace_sample = (dst_mask_val & o_bit) != o_bit;
        //     if (src_visible && !replace_sample && overwrite_unoccupied)
        //         replace_sample = (dst_mask_val & v_bit) != v_bit;
        //     dst_mask_val |= s_val;
        // }
        // return replace_sample;

        const uint3 img_res = g_lubo.display_res;
        const uint num_layers_per_eye = g_lubo.screen_space_num_layers;
        const uint3 img_destination = uint3(dst_uv.xy * img_res.xy, gId.z);
        // NOTE(review): the mask is read here before the bounds test below - confirm
        // that an out-of-range read is harmless on the target API.
        uint dst_mask_val = dst_cache_mask[img_destination];
        bool replace_sample = false;
        target_slot = 0;

        // reject destinations outside the image or with dst_uv.z outside [0, 1]
        if (any(img_destination >= img_res) || dst_uv.z < 0 || dst_uv.z > 1)
            return replace_sample;

        // try to find a slot until the target is completely full!
        while (!is_mask_full(dst_mask_val, num_layers_per_eye))
        {
            MaskData dst_mask = calculate_masks(dst_mask_val, num_layers_per_eye);

            if (dst_mask.occupancy_mask != occupancy_bits(num_layers_per_eye))
            {
                // at least one slot is unoccupied: go for the lowest one
                target_slot = firstbitlow(~dst_mask.occupancy_mask);
                const uint dst_occupancy_bit = 1u << (target_slot);
                const uint dst_visibility_bit = 1u << (target_slot + num_layers_per_eye);
                uint dst_val = dst_occupancy_bit;
                if (src_visible)
                    dst_val |= dst_visibility_bit;
                // dst_mask_val receives the mask as it was before the OR
                InterlockedOr(dst_cache_mask[img_destination], dst_val, dst_mask_val);
                // occupancy bit was still clear before the OR => this call claimed the slot
                if ((dst_mask_val & dst_occupancy_bit) != dst_occupancy_bit)
                {
                    replace_sample = true;
                    break;
                }
                // slot was already occupied: fold our bits into the local copy and retry
                dst_mask_val |= dst_val;
            }
            else if (src_visible)
            {
                // every slot is occupied: a visible source may take over the lowest
                // slot that is not marked occupied-and-visible
                target_slot = firstbitlow(~dst_mask.visible_samples_mask);
                const uint dst_occupancy_bit = 1u << (target_slot);
                const uint dst_visibility_bit = 1u << (target_slot + num_layers_per_eye);
                uint dst_val = dst_occupancy_bit;
                if (src_visible)
                    dst_val |= dst_visibility_bit;
                InterlockedOr(dst_cache_mask[img_destination], dst_val, dst_mask_val);
                // visibility bit was still clear before the OR => this call claimed the slot
                if ((dst_mask_val & dst_visibility_bit) != dst_visibility_bit)
                {
                    replace_sample = true;
                    break;
                }
                dst_mask_val |= dst_val;
            }
            else
            {
                // every slot is occupied and the source is not visible: give up
                break;
            }
        }
        return replace_sample;
    }
    else
    {
        // Interactive Stable Ray Tracing
        // The slot is fixed by the sub-pixel the sample falls into, so a single
        // atomic OR is enough and there is no retry loop.
        // NOTE(review): unlike the mask_mode 0 path there is no bounds or depth test on dst_uv here - confirm callers guarantee it.
        uint2 M = uint2(g_lubo.screen_M, g_lubo.screen_M);
        uint3 cache_idx = uint3(dst_uv.xy * cache_res.xy, gId.z);
        uint2 dst_mask_pixel = cache_idx.xy / M.xy;
        uint2 slot = cache_idx.xy % M.xy;
        uint dst_subpixel = slot.y * M.x + slot.x;
        uint dst_occupancy_bit = 1u << dst_subpixel;
        uint dst_visibility_bit = 1u << (dst_subpixel + M.x * M.y);
        uint bit_mask = dst_occupancy_bit | (src_visible ? dst_visibility_bit : 0u);
        uint original_bitmask = 0;
        InterlockedOr(dst_cache_mask[uint3(dst_mask_pixel, gId.z)], bit_mask, original_bitmask);
        bool dst_occupied = (dst_occupancy_bit & original_bitmask) == dst_occupancy_bit;
        bool original_is_visible = (dst_visibility_bit & original_bitmask) == dst_visibility_bit;

        // claim the slot when it was empty, or when a visible source meets a slot not yet flagged visible
        bool replace_sample = !dst_occupied || (src_visible && !original_is_visible);
        if (replace_sample)
            target_slot = dst_subpixel;
        else
            target_slot = 0;

        return replace_sample;
    }

}

// Bit index of the n-th set bit of mask_val, counted from the lowest bit
// (n = 0 is the lowest). n is clamped to the last set bit; for an empty mask
// the result is whatever firstbitlow(0) yields.
uint choose_nth_bit(uint mask_val, int n)
{
    const int population = countbits(mask_val);
    // drop the lowest set bit once per skipped rank, always leaving at least one bit
    for (int to_drop = min(population - 1, n); to_drop > 0; to_drop--)
        mask_val &= mask_val - 1u;
    return firstbitlow(mask_val);
}

// Chooses the slot to work on next for this mask value.
// An uninitialized slot (visible but not occupied) always wins; the highest
// such slot is returned. Otherwise mode 0 cycles with g_lubo.sample_idx and
// the other mode picks a random occupied slot.
// Returns 0 when the mask has neither uninitialized nor occupied slots.
uint choose_target_sample(uint mask_val, uint num_slots)
{
    MaskData mask = calculate_masks(mask_val, num_slots);
    if (mask.uninitialized_samples_mask != 0u)
        return firstbithigh(mask.uninitialized_samples_mask);

    // nothing occupied: both paths below would compute an integer modulo by zero
    if (mask.occupancy_mask == 0u)
        return 0u;

    // NOTE(review): in mode 0 the result is a rank (0 .. count-1), not a bit
    // position; that only names an occupied slot if occupied slots are packed
    // from bit 0 - confirm against the callers.
    if (mask_mode == 0)
        return g_lubo.sample_idx % countbits(mask.occupancy_mask);
    else
        return choose_random_sample(mask.occupancy_mask, num_slots);
}

// Picks one of the occupied-and-visible slots of mask_val: the
// (sample_idx modulo visible count)-th one, counted from the lowest bit.
// With no visible sample the result is choose_nth_bit(0, 0), i.e. whatever
// firstbitlow(0) yields (presumably 0xFFFFFFFF - confirm for the target API).
uint choose_shading_sample(uint mask_val, uint sample_idx, uint num_slots)
{
    uint visible_mask = calculate_masks(mask_val, num_slots).visible_samples_mask;
    uint nb_bits = countbits(visible_mask);
    // clamp the divisor so an empty visible set does not cause a modulo by zero
    uint n = sample_idx % max(nb_bits, 1u);
    return choose_nth_bit(visible_mask, n);
}

// Removes one sample from mask_val: a random occupied-but-invisible slot if
// there is one, otherwise a random occupied slot. When nothing is occupied,
// slot 0 is the one that gets disabled. gId is not used.
void remove_sample(inout uint mask_val, uint3 gId)
{
    // slot count of the active layout
    uint num_slots;
    if (mask_mode == 0)
        num_slots = g_lubo.screen_space_num_layers;
    else
        num_slots = g_lubo.screen_M * g_lubo.screen_M;

    MaskData mask = calculate_masks(mask_val, num_slots);

    uint victim = 0;
    if (mask.invisible_samples_mask != 0)
    {
        victim = choose_random_sample(mask.invisible_samples_mask, num_slots);
    }
    else if (mask.occupancy_mask != 0)
    {
        // with no invisible samples, every occupied slot is visible, so testing
        // occupancy here is the same as testing the visible-samples mask
        victim = choose_random_sample(mask.occupancy_mask, num_slots);
    }

    // clear both bits of the selected slot
    mask_val = disable_sample(mask_val, victim, num_slots);
}

// Reserves a slot for a new sample, marks it uninitialized in mask_val
// (see make_sample) and returns its index. While fewer than
// g_lubo.screen_space_num_layers slots are active a random inactive slot is
// used; after that a random slot whose visibility bit is clear. gId is not used.
// NOTE(review): if every slot is already visible the second choice has no
// candidate - confirm callers never reach that state.
uint create_sample(inout uint mask_val, uint3 gId)
{
    // slot count of the active layout; in mode 0 it equals the activity limit below
    uint num_slots;
    if (mask_mode == 0)
        num_slots = g_lubo.screen_space_num_layers;
    else
        num_slots = g_lubo.screen_M * g_lubo.screen_M;

    MaskData mask = calculate_masks(mask_val, num_slots);

    // candidates: empty slots first, failing that slots that are not visible
    uint candidates;
    if (countbits(mask.active_samples_mask) < g_lubo.screen_space_num_layers)
        candidates = ~mask.active_samples_mask;
    else
        candidates = ~mask.visibility_mask;

    uint new_slot = choose_random_sample(candidates, num_slots);
    mask_val = make_sample(mask_val, new_slot, num_slots);
    return new_slot;
}
