#include <stdio.h>
#include <stdlib.h>
#include <inttypes.h>
#include <assert.h>

// Sign-extend the low `size` bits of `word` to a full 32-bit value.
//
// word: value whose bit (size - 1) is treated as the sign bit; any bits above
//       the field are discarded.
// size: width of the field in bits, must be in the range 1..32.
//
// Returns the two's-complement extension of the field.
//
// Declared `static inline`: a plain `inline` definition provides no external
// definition in C99/C11, so a call the compiler chooses not to inline would
// fail to link.
static inline uint32_t sign_extend(uint32_t word, uint32_t size)
{
    assert(size >= 1 && size <= 32);

    const uint32_t sign_bit   = UINT32_C(1) << (size - 1);
    // For size == 32 the shift wraps to 0 and the mask becomes 0xFFFFFFFF.
    const uint32_t field_mask = (sign_bit << 1) - 1;

    // Flipping the sign bit and subtracting it back propagates it upwards.
    return ((word & field_mask) ^ sign_bit) - sign_bit;
}

// Decoded fields of one 32-bit instruction word. Each decode_*_type()
// function fills in the fields its format carries and leaves the others zero.
struct Instruction
{
    uint8_t  opcode;  // bits 6:0 of the word
    uint8_t  rs1;     // bits 19:15, used as an index into Hart.regs
    uint8_t  rs2;     // bits 24:20, used as an index into Hart.regs
    uint8_t  rd;      // bits 11:7, index of the register that receives the result
    uint8_t  funct3;  // bits 14:12, selects the operation within an opcode
    uint8_t  funct7;  // bits 31:25 (only set by decode_r_type)
    uint32_t imm;     // immediate assembled by the format-specific decoder
};

struct Instruction decode_r_type(uint32_t word)
{
    struct Instruction instruction = {0};
    instruction.opcode =  word        & 0x7F;
    instruction.rd     = (word >> 7)  & 0x1F;
    instruction.funct3 = (word >> 12) & 0x07;
    instruction.rs1    = (word >> 15) & 0x1F;
    instruction.rs2    = (word >> 20) & 0x1F;
    instruction.funct7 =  word >> 25;
    return instruction;
};

struct Instruction decode_i_type(uint32_t word)
{
    struct Instruction instruction = {0};
    instruction.opcode =  word        & 0x7F;
    instruction.rd     = (word >> 7)  & 0x1F;
    instruction.funct3 = (word >> 12) & 0x07;
    instruction.rs1    = (word >> 15) & 0x1F;
    instruction.imm    = sign_extend(word >> 20, 12);
    return instruction;
};

// Split `word` into its S-type fields: opcode, funct3, rs1, rs2 and a 12-bit
// sign-extended immediate. rd and funct7 are left at zero.
//
// The immediate is stored in two pieces:
//   imm[11:5] = word[31:25]
//   imm[4:0]  = word[11:7]
struct Instruction decode_s_type(uint32_t word)
{
    struct Instruction instruction = {0};
    instruction.opcode =  word        & 0x7F;
    instruction.funct3 = (word >> 12) & 0x07;
    instruction.rs1    = (word >> 15) & 0x1F;
    instruction.rs2    = (word >> 20) & 0x1F;
    // Shifting by 20 drops word[31:25] onto bits 11:5; the mask discards rs2.
    instruction.imm    = sign_extend(((word >> 20) & 0xFE0) | ((word >> 7) & 0x1F), 12);
    return instruction;
}

// Split `word` into its B-type fields: opcode, funct3, rs1, rs2 and a 13-bit
// sign-extended branch offset whose bit 0 is always zero. rd and funct7 are
// left at zero.
//
// The offset is scattered across the word:
//   imm[12]   = word[31]
//   imm[11]   = word[7]
//   imm[10:5] = word[30:25]
//   imm[4:1]  = word[11:8]
struct Instruction decode_b_type(uint32_t word)
{
    struct Instruction instruction = {0};
    instruction.opcode =  word        & 0x7F;
    instruction.funct3 = (word >> 12) & 0x07;
    instruction.rs1    = (word >> 15) & 0x1F;
    instruction.rs2    = (word >> 20) & 0x1F;
    instruction.imm    = sign_extend(
        ((word & 0x80000000) >> 19) |
        ((word & 0x00000080) << 4)  |
        ((word & 0x7E000000) >> 20) |
        ((word & 0x00000F00) >> 7), 13);
    return instruction;
}

struct Instruction decode_u_type(uint32_t word)
{
    struct Instruction instruction = {0};
    instruction.opcode =  word        & 0x7F;
    instruction.rd     = (word >> 7)  & 0x1F;
    instruction.imm    =  word        & 0xFFFFF000;
    return instruction;
};

struct Instruction decode_j_type(uint32_t word)
{
    struct Instruction instruction = {0};
    instruction.opcode =  word        & 0x7F;
    instruction.rd     = (word >> 7)  & 0x1F;
    instruction.imm    = sign_extend(
        ((word & 0x80000000) >> 11) |
        ((word & 0x000FF000) >> 0)  |
        ((word & 0x00100000) >> 9)  |
        ((word & 0x7FE00000) >> 20), 21);
    return instruction;
}

// Execution state that the execute*() functions read and update.
struct Hart
{
    uint32_t pc;        // program counter; read by AUIPC in execute()
    uint32_t regs[32];  // integer registers, indexed by the 5-bit rs1/rs2/rd fields
};

// Execute one OP-IMM instruction (register-immediate arithmetic and logic).
//
// hart:        register state; regs[rs1] is read and regs[rd] is written.
// instruction: raw 32-bit instruction word, decoded here as I-type.
//
// x0 is hard-wired to zero in RISC-V, so a result addressed to rd == 0 is
// discarded instead of being stored.
void execute_op_imm(struct Hart* hart, uint32_t instruction)
{
    const struct Instruction inst = decode_i_type(instruction);
    const uint32_t src = hart->regs[inst.rs1];
    uint32_t result;

    switch (inst.funct3)
    {
        case 0: // ADDI
            result = src + inst.imm;
            break;
        case 1: // SLLI
            result = src << (inst.imm & 0x1F);
            break;
        case 2: // SLTI
            result = (int32_t)src < (int32_t)inst.imm ? 1 : 0;
            break;
        case 3: // SLTIU (the sign-extended immediate is compared as unsigned)
            result = src < inst.imm ? 1 : 0;
            break;
        case 4: // XORI
            result = src ^ inst.imm;
            break;
        case 5: // SRLI, SRAI
        {
            const uint32_t shamt = inst.imm & 0x1F;
            result = src >> shamt;
            // Immediate bit 10 (word bit 30) selects the arithmetic shift:
            // re-extend the sign from the new top bit of the shifted value.
            if ((inst.imm & 0x400) && shamt > 0) { result = sign_extend(result, 32 - shamt); }
            break;
        }
        case 6: // ORI
            result = src | inst.imm;
            break;
        case 7: // ANDI
            result = src & inst.imm;
            break;
        default:
            assert(!"Unhandled OP-IMM");
            return;
    }

    if (inst.rd != 0) { hart->regs[inst.rd] = result; }
}

// Execute one OP instruction (register-register arithmetic and logic).
//
// hart:        register state; regs[rs1] and regs[rs2] are read and regs[rd]
//              is written.
// instruction: raw 32-bit instruction word, decoded here as R-type.
//
// Word bit 30 selects SUB over ADD and SRA over SRL. x0 is hard-wired to zero
// in RISC-V, so a result addressed to rd == 0 is discarded instead of being
// stored.
void execute_op(struct Hart* hart, uint32_t instruction)
{
    const struct Instruction inst = decode_r_type(instruction);
    const uint32_t lhs = hart->regs[inst.rs1];
    const uint32_t rhs = hart->regs[inst.rs2];
    const int alternate = (instruction & 0x40000000) != 0;
    uint32_t result;

    switch (inst.funct3)
    {
        case 0: // ADD, SUB
            result = alternate ? lhs - rhs : lhs + rhs;
            break;
        case 1: // SLL
            result = lhs << (rhs & 0x1F);
            break;
        case 2: // SLT
            result = (int32_t)lhs < (int32_t)rhs ? 1 : 0;
            break;
        case 3: // SLTU
            result = lhs < rhs ? 1 : 0;
            break;
        case 4: // XOR
            result = lhs ^ rhs;
            break;
        case 5: // SRL, SRA
        {
            const uint32_t shamt = rhs & 0x1F;
            result = lhs >> shamt;
            // Arithmetic shift: re-extend the sign from the new top bit.
            if (alternate && shamt > 0) { result = sign_extend(result, 32 - shamt); }
            break;
        }
        case 6: // OR
            result = lhs | rhs;
            break;
        case 7: // AND
            result = lhs & rhs;
            break;
        default:
            assert(!"Unhandled OP");
            return;
    }

    if (inst.rd != 0) { hart->regs[inst.rd] = result; }
}

// Execute a single instruction word against `hart`, dispatching on the
// 7-bit opcode in bits 6:0.
//
// Handled opcodes: OP-IMM (0x13), AUIPC (0x17), OP (0x33) and LUI (0x37); any
// other opcode trips an assertion. hart->pc is read by AUIPC but is not
// advanced here.
void execute(struct Hart* hart, uint32_t instruction)
{
    switch (instruction & 0x7f)
    {
        case 0x13:
            execute_op_imm(hart, instruction);
            break;
        case 0x17: // AUIPC
        {
            struct Instruction inst = decode_u_type(instruction);
            hart->regs[inst.rd] = inst.imm + hart->pc;
            break;
        }
        case 0x33:
            execute_op(hart, instruction);
            break;
        case 0x37: // LUI
        {
            struct Instruction inst = decode_u_type(instruction);
            hart->regs[inst.rd] = inst.imm;
            break;
        }
        default:
            assert(!"Unhandled opcode");
    }

    // x0 is hard-wired to zero in RISC-V: undo any write that targeted it.
    hart->regs[0] = 0;
}

// Checks ADDI with a positive and a negative (sign-extended) immediate.
void test_addi(void)
{
    struct Hart hart = {0};

    execute(&hart, 0x00500093); // addi x1, x0, 5
    assert(hart.regs[1] == 5);

    execute(&hart, 0xffe00093); // addi x1, x0, -2
    assert(hart.regs[1] == 0xfffffffe);
}

// Checks SLTIU and SLTI against immediates above, equal to and below the
// register value, plus signed comparisons with a negative immediate.
void test_slti_sltiu(void)
{
    struct Hart hart = {0};

    hart.regs[1] = 5;

    execute(&hart, 0x00f0b113); // sltiu x2, x1, 15
    assert(hart.regs[2] == 1);

    execute(&hart, 0x0050b113); // sltiu x2, x1, 5
    assert(hart.regs[2] == 0);

    execute(&hart, 0x0010b113); // sltiu x2, x1, 1
    assert(hart.regs[2] == 0);

    execute(&hart, 0x00f0a113); // slti x2, x1, 15
    assert(hart.regs[2] == 1);

    execute(&hart, 0x0050a113); // slti x2, x1, 5
    assert(hart.regs[2] == 0);

    execute(&hart, 0x0010a113); // slti x2, x1, 1
    assert(hart.regs[2] == 0);

    execute(&hart, 0xffb0a113); // slti x2, x1, -5
    assert(hart.regs[2] == 0);

    hart.regs[1] = (uint32_t)-20;

    execute(&hart, 0xffb0a113); // slti x2, x1, -5
    assert(hart.regs[2] == 1);
}

// Checks the bitwise immediate operations XORI, ORI and ANDI on 6 and 12.
void test_andi_ori_xori(void)
{
    struct Hart hart = {0};

    hart.regs[1] = 6;

    execute(&hart, 0x00c0c113); // xori x2, x1, 12
    assert(hart.regs[2] == 10);

    execute(&hart, 0x00c0e113); // ori x2, x1, 12
    assert(hart.regs[2] == 14);

    execute(&hart, 0x00c0f113); // andi x2, x1, 12
    assert(hart.regs[2] == 4);
}

// Checks the immediate shifts SLLI, SRLI and SRAI, including that SRAI keeps
// the sign of a negative operand while SRLI shifts in zeros.
void test_slli_srli_srai(void)
{
    struct Hart hart = {0};

    hart.regs[1] = 6;

    execute(&hart, 0x00209113); // slli x2, x1, 2
    assert(hart.regs[2] == 24);

    execute(&hart, 0x0020d113); // srli x2, x1, 2
    assert(hart.regs[2] == 1);

    execute(&hart, 0x4020d113); // srai x2, x1, 2
    assert(hart.regs[2] == 1);

    hart.regs[1] = (uint32_t)-6;

    execute(&hart, 0x0020d113); // srli x2, x1, 2
    assert(hart.regs[2] == 0x3FFFFFFE);

    execute(&hart, 0x4020d113); // srai x2, x1, 2
    assert(hart.regs[2] == 0xFFFFFFFE);
}

// Checks LUI and AUIPC; AUIPC is exercised with pc == 0 and pc == 12.
void test_lui_auipc(void)
{
    struct Hart hart = {0};

    execute(&hart, 0x0007b0b7); // lui x1, 0x7b  (x1 = 0x7b000 = 503808)
    assert(hart.regs[1] == 503808);

    execute(&hart, 0x0007b097); // auipc x1, 0x7b  (x1 = pc + 503808)
    assert(hart.regs[1] == 503808);

    hart.pc = 12;
    execute(&hart, 0x0007b097); // auipc x1, 0x7b  (x1 = pc + 503808)
    assert(hart.regs[1] == 503820);
}

// Checks every register-register operation handled by execute_op():
// ADD/SUB, SLT/SLTU, OR/XOR/AND and SLL/SRL/SRA.
void test_op(void)
{
    struct Hart hart = {0};

    hart.regs[1] = 3;
    hart.regs[2] = 4;
    hart.regs[4] = (uint32_t)-1;

    execute(&hart, 0x002081b3); // add x3, x1, x2
    assert(hart.regs[3] == 7);

    execute(&hart, 0x402081b3); // sub x3, x1, x2
    assert(hart.regs[3] == 0xFFFFFFFF);

    execute(&hart, 0x0020a1b3); // slt x3, x1, x2
    assert(hart.regs[3] == 1);

    execute(&hart, 0x001121b3); // slt x3, x2, x1
    assert(hart.regs[3] == 0);

    execute(&hart, 0x001131b3); // sltu x3, x2, x1
    assert(hart.regs[3] == 0);

    execute(&hart, 0x0020b1b3); // sltu x3, x1, x2
    assert(hart.regs[3] == 1);

    execute(&hart, 0x0040a1b3); // slt x3, x1, x4
    assert(hart.regs[3] == 0);

    hart.regs[1] = 6;
    hart.regs[2] = 12;

    execute(&hart, 0x0020e1b3); // or x3, x1, x2
    assert(hart.regs[3] == 14);

    execute(&hart, 0x0020c1b3); // xor x3, x1, x2
    assert(hart.regs[3] == 10);

    execute(&hart, 0x0020f1b3); // and x3, x1, x2
    assert(hart.regs[3] == 4);

    hart.regs[1] = 6;
    hart.regs[2] = 2;

    execute(&hart, 0x002091b3); // sll x3, x1, x2
    assert(hart.regs[3] == 24);

    execute(&hart, 0x0020d1b3); // srl x3, x1, x2
    assert(hart.regs[3] == 1);

    execute(&hart, 0x4020d1b3); // sra x3, x1, x2
    assert(hart.regs[3] == 1);

    hart.regs[1] = (uint32_t)-6;
    hart.regs[2] = 2;

    execute(&hart, 0x4020d1b3); // sra x3, x1, x2
    assert(hart.regs[3] == 0xFFFFFFFE);
}

// Runs every instruction self-test in turn; a failing check aborts through
// assert().
void test(void)
{
    test_addi();
    test_slti_sltiu();
    test_andi_ori_xori();
    test_slli_srli_srai();
    test_lui_auipc();
    test_op();
}

// Program entry point: runs the self-tests and reports success. The command
// line is not used.
int main(int argc, char* argv[])
{
    (void)argc;
    (void)argv;

    test();

    return EXIT_SUCCESS;
}