#include <stdlib.h>
#include <inttypes.h>
#include <assert.h>
#include <stdbool.h>

#include "emulator/hart.h"

static void test_addi()
{
    struct Hart hart = {0};
    
    execute(&hart, 0x00500093); // addi x1, x0, 5
    assert(hart.regs[1] == 5);

    execute(&hart, 0xffe00093); // addi, x1, x0, -2
    assert(hart.regs[1] == 0xfffffffe);
}

// slti / sltiu: x2 = (x1 < sign-extended immediate), compared signed / unsigned.
static void test_slti_sltiu()
{
    struct Hart hart = {0};

    hart.regs[1] = 5;

    execute(&hart, 0x00f0b113); // sltiu x2, x1, 15
    assert(hart.regs[2] == 1);

    execute(&hart, 0x0050b113); // sltiu x2, x1, 5
    assert(hart.regs[2] == 0);

    execute(&hart, 0x0010b113); // sltiu x2, x1, 1
    assert(hart.regs[2] == 0);

    execute(&hart, 0x00f0a113); // slti x2, x1, 15
    assert(hart.regs[2] == 1);

    execute(&hart, 0x0050a113); // slti x2, x1, 5
    assert(hart.regs[2] == 0);

    execute(&hart, 0x0010a113); // slti x2, x1, 1
    assert(hart.regs[2] == 0);

    execute(&hart, 0xffb0a113); // slti x2, x1, -5
    assert(hart.regs[2] == 0);

    hart.regs[1] = (uint32_t)-20;

    execute(&hart, 0xffb0a113); // slti x2, x1, -5
    assert(hart.regs[2] == 1);

    // sltiu sign-extends the immediate before the unsigned compare, so -5
    // becomes 0xFFFFFFFB and x1 = 0xFFFFFFEC (-20) is below it. An emulator
    // that zero-extended the immediate (0xFFB) would produce 0 here.
    hart.regs[2] = 0; // clear so the assert proves execute() wrote x2
    execute(&hart, 0xffb0b113); // sltiu x2, x1, -5
    assert(hart.regs[2] == 1);
}

static void test_andi_ori_xori()
{
    struct Hart hart = {0};
    
    hart.regs[1] = 6;

    execute(&hart, 0x00c0c113); // xori x2, x1, 12
    assert(hart.regs[2] == 10);
    
    execute(&hart, 0x00c0e113); // ori x2, x1, 12
    assert(hart.regs[2] == 14);
    
    execute(&hart, 0x00c0f113); // andi x2, x1, 12
    assert(hart.regs[2] == 4);
}

static void test_slli_srli_srai()
{
    struct Hart hart = {0};
    
    hart.regs[1] = 6;

    execute(&hart, 0x00209113); // slli x2, x1, 2
    assert(hart.regs[2] == 24);

    execute(&hart, 0x0020d113); // srli x2, x1, 2
    assert(hart.regs[2] == 1);

    execute(&hart, 0x4020d113); // srai x2, x1, 2
    assert(hart.regs[2] == 1);

    hart.regs[1] = (uint32_t)-6;

    execute(&hart, 0x0020d113); // srli x2, x1, 2
    assert(hart.regs[2] == 0x3FFFFFFE);

    execute(&hart, 0x4020d113); // srai x2, x1, 2
    assert(hart.regs[2] == 0xFFFFFFFE);
}

// lui / auipc: the 20-bit immediate is placed in bits 31:12. lui writes it
// directly; auipc adds it to the pc of the instruction.
static void test_lui_auipc()
{
    struct Hart hart = {0};

    hart.pc = 0;
    execute(&hart, 0x0007b0b7); // lui x1, 0x7b    (x1 = 0x7b << 12 = 503808)
    assert(hart.regs[1] == 503808);

    // pc is reset before each auipc because execute() may have advanced it.
    hart.pc = 0;
    execute(&hart, 0x0007b097); // auipc x1, 0x7b  (x1 = pc + 503808)
    assert(hart.regs[1] == 503808);

    hart.pc = 12;
    execute(&hart, 0x0007b097); // auipc x1, 0x7b  (x1 = 12 + 503808)
    assert(hart.regs[1] == 503820);

    // All 20 immediate bits land in the upper part; the low 12 bits are zero.
    execute(&hart, 0xfffff0b7); // lui x1, 0xfffff
    assert(hart.regs[1] == 0xFFFFF000);
}

// Register-register ALU ops (OP opcode), result in x3.
static void test_op()
{
    struct Hart hart = {0};

    hart.regs[1] = 3;
    hart.regs[2] = 4;
    hart.regs[4] = (uint32_t)-1;

    execute(&hart, 0x002081b3); // add x3, x1, x2
    assert(hart.regs[3] == 7);

    execute(&hart, 0x402081b3); // sub x3, x1, x2
    assert(hart.regs[3] == 0xFFFFFFFF);

    execute(&hart, 0x0020a1b3); // slt x3, x1, x2
    assert(hart.regs[3] == 1);

    execute(&hart, 0x001121b3); // slt x3, x2, x1
    assert(hart.regs[3] == 0);

    execute(&hart, 0x001131b3); // sltu x3, x2, x1
    assert(hart.regs[3] == 0);

    execute(&hart, 0x0020b1b3); // sltu x3, x1, x2
    assert(hart.regs[3] == 1);

    // x4 = -1: less than 3 when signed, but 0xFFFFFFFF is greater than 3
    // when unsigned, so slt and sltu must disagree.
    execute(&hart, 0x0040a1b3); // slt x3, x1, x4
    assert(hart.regs[3] == 0);

    execute(&hart, 0x0040b1b3); // sltu x3, x1, x4
    assert(hart.regs[3] == 1);

    hart.regs[1] = 6;
    hart.regs[2] = 12;

    execute(&hart, 0x0020e1b3); // or x3, x1, x2
    assert(hart.regs[3] == 14);

    execute(&hart, 0x0020c1b3); // xor x3, x1, x2
    assert(hart.regs[3] == 10);

    execute(&hart, 0x0020f1b3); // and x3, x1, x2
    assert(hart.regs[3] == 4);

    hart.regs[1] = 6;
    hart.regs[2] = 2;

    execute(&hart, 0x002091b3); // sll x3, x1, x2
    assert(hart.regs[3] == 24);

    execute(&hart, 0x0020d1b3); // srl x3, x1, x2
    assert(hart.regs[3] == 1);

    execute(&hart, 0x4020d1b3); // sra x3, x1, x2
    assert(hart.regs[3] == 1);

    hart.regs[1] = (uint32_t)-6;
    hart.regs[2] = 2;

    execute(&hart, 0x4020d1b3); // sra x3, x1, x2
    assert(hart.regs[3] == 0xFFFFFFFE);
}

static void test_jal()
{
    struct Hart hart = {0};

    hart.pc = 12;

    execute(&hart, 0x12c000ef); // jal x1, 300
    assert(hart.regs[1] == 16);
    assert(hart.pc == 312);

    execute(&hart, 0xed5ff0ef); // jal x1, -300
    assert(hart.regs[1] == 316);
    assert(hart.pc == 12);
}

// jalr: pc = (x1 + sign-extended immediate) with bit 0 cleared; x2 = old pc + 4.
static void test_jalr()
{
    struct Hart hart = {0};

    hart.pc = 12;
    hart.regs[1] = 300;

    execute(&hart, 0x00a08167); // jalr x2, 10(x1)
    assert(hart.regs[2] == 16);
    assert(hart.pc == 310);

    execute(&hart, 0xff608167); // jalr x2, -10(x1)
    assert(hart.regs[2] == 314);
    assert(hart.pc == 290);

    // 300 + 1 = 301 is odd; JALR clears the least-significant bit of the
    // computed target, so execution continues at 300.
    execute(&hart, 0x00108167); // jalr x2, 1(x1)
    assert(hart.regs[2] == 294);
    assert(hart.pc == 300);
}

// Conditional branches. Every instruction below encodes rs1 = x1, rs2 = x2 and
// an offset of 24 from pc 100: taken lands on 124, not taken falls to 104.
//
// x1 = 2 and x2 = 0xFFFFFFFC (-4). As signed values x1 > x2; as unsigned
// values x1 < x2. Per the RISC-V spec, BLT/BLTU branch when rs1 < rs2 and
// BGE/BGEU branch when rs1 >= rs2.
static void test_branch()
{
    struct Hart hart = {0};

    hart.pc = 100;
    hart.regs[1] = 2;
    hart.regs[2] = 0xFFFFFFFC;

    execute(&hart, 0x00208c63); // beq x1, x2, 24   (2 != -4: not taken)
    assert(hart.pc == 104);

    hart.pc = 100;
    execute(&hart, 0x00209c63); // bne x1, x2, 24   (taken)
    assert(hart.pc == 124);

    hart.pc = 100;
    execute(&hart, 0x0020cc63); // blt x1, x2, 24   (2 < -4 is false: not taken)
    assert(hart.pc == 104);

    hart.pc = 100;
    execute(&hart, 0x0020dc63); // bge x1, x2, 24   (2 >= -4: taken)
    assert(hart.pc == 124);

    hart.pc = 100;
    execute(&hart, 0x0020ec63); // bltu x1, x2, 24  (2 < 0xFFFFFFFC: taken)
    assert(hart.pc == 124);

    hart.pc = 100;
    execute(&hart, 0x0020fc63); // bgeu x1, x2, 24  (not taken)
    assert(hart.pc == 104);

    // Equal operands: beq must take the branch.
    hart.pc = 100;
    hart.regs[2] = 2;
    execute(&hart, 0x00208c63); // beq x1, x2, 24   (2 == 2: taken)
    assert(hart.pc == 124);
}

void test()
{
    test_addi();
    test_slti_sltiu();
    test_andi_ori_xori();
    test_slli_srli_srai();
    test_lui_auipc();
    test_op();
    test_jal();
    test_jalr();
    test_branch();
}

// Test-runner entry point. Every check in this file is an assert(), so
// building with NDEBUG would compile them all away and report success.
// Refuse to build in that configuration instead.
int main(int argc, char* argv[])
{
#ifdef NDEBUG
#error "emulator tests rely on assert(); do not compile them with NDEBUG"
#endif
    // No command-line options are supported.
    (void)argc;
    (void)argv;

    test();
    return EXIT_SUCCESS;
}