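/*
 * pmpfun.c - kernel driver exposing IOCTLs that set, clear and query the
 *            protected-process bit of a process.
 *
 * Author: Andy Ying
 *   Date: April 20, 2007
 */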

#include <ntddk.h>

#include "pmpfun.h"

const WCHAR DeviceLink[] = L"\\DosDevices\\PmpFun";
const WCHAR DeviceName[] = L"\\Device\\PmpFun";

PDEVICE_OBJECT PmpDevice;

/*
 * DriverUnload - undo what DriverEntry set up: delete the symbolic link and
 * the control device (if one exists), then log that the driver unloaded.
 *
 * DriverObject: the driver object passed to DriverEntry.
 *
 * NOTE(review): the WDK DRIVER_UNLOAD callback type returns VOID, so the I/O
 * manager ignores the value returned here. The NTSTATUS return type is kept
 * only so this definition still matches any prototype in pmpfun.h.
 */
NTSTATUS
DriverUnload(PDRIVER_OBJECT DriverObject)
{
	UNICODE_STRING DeviceLinkString;

	if (DriverObject->DeviceObject != NULL) {
		RtlInitUnicodeString(&DeviceLinkString, DeviceLink);
		IoDeleteSymbolicLink(&DeviceLinkString);

		IoDeleteDevice(DriverObject->DeviceObject);
	}

	/* Previously unreachable whenever a device existed (early return). */
	DbgPrint("PmpFun: Unloaded\n");

	return STATUS_SUCCESS;
}

NTSTATUS
PmpControl(IN PDEVICE_OBJECT DeviceObject, IN PIRP Irp)
{
	PIO_STACK_LOCATION IrpStack;
	
	ULONG ProcessId;
	PULONG ProtectPid;
	PULONG BaseAddr;

	IN PVOID InputBuffer;
	IN ULONG InputBufferLength;
	OUT PVOID OutputBuffer;
	OUT ULONG OutputBufferLength;

	ULONG IoControlCode;
	NTSTATUS NtStatus;

	NtStatus = Irp->IoStatus.Status = STATUS_SUCCESS;
	Irp->IoStatus.Information = 0;

	IrpStack = IoGetCurrentIrpStackLocation(Irp);

	InputBuffer = Irp->AssociatedIrp.SystemBuffer;
	InputBufferLength = IrpStack->Parameters.DeviceIoControl.InputBufferLength;
	OutputBuffer = Irp->AssociatedIrp.SystemBuffer;
	OutputBufferLength = IrpStack->Parameters.DeviceIoControl.OutputBufferLength;
	IoControlCode = IrpStack->Parameters.DeviceIoControl.IoControlCode;

	switch (IoControlCode) {
		case IOCTL_SET_PMP:
			DbgPrint("PmpFun: setpmp inbuflen = %i || inbuf = 0x%x\n", InputBufferLength, InputBuffer);

			if ((InputBufferLength < sizeof(ULONG)) || (InputBuffer == NULL)) {
				DbgPrint("Pmpfun: bad InputBuffer\n");
				
				Irp->IoStatus.Status = STATUS_INVALID_BUFFER_SIZE;

				break;
			}

			DbgPrint("Pmpfun: InputBuffer = 0x%x\n", InputBuffer);

			ProcessId = *((PULONG) InputBuffer);

			DbgPrint("Pmpfun: process Id %u\n", ProcessId);

			BaseAddr = PmpGetProcessByProcessId(ProcessId);
			
			if (BaseAddr == NULL) {
				DbgPrint("Pmpfun: failed to find PID\n");

				Irp->IoStatus.Status = STATUS_INVALID_BUFFER_SIZE;

				break;
			}

			ProtectPid = (PULONG)((ULONG) BaseAddr + PROTECTED_PROCESS);

			DbgPrint("Pmpfun: setting protected process mode on %u\n", ProcessId);
			DbgPrint("Pmpfun: previous 0x%x\n", *ProtectPid);

			// Set the bit to 1 
			*ProtectPid = *ProtectPid | PROTECTED_BIT;
			
			DbgPrint("Pmpfun: after    0x%x\n", *ProtectPid);

			break;
		case IOCTL_RESET_PMP:
			DbgPrint("PmpFun: resetpmp inbuflen = %i || inbuf = 0x%x\n", InputBufferLength, InputBuffer);

			if ((InputBufferLength < sizeof(ULONG)) || (InputBuffer == NULL)) {
				DbgPrint("Pmpfun: bad InputBuffer\n");
				
				Irp->IoStatus.Status = STATUS_INVALID_BUFFER_SIZE;

				break;
			}
			
			DbgPrint("Pmpfun: InputBuffer = 0x%x\n", InputBuffer);

			ProcessId = *((PULONG) InputBuffer);

			DbgPrint("Pmpfun: disabling protected process mode on %u\n", ProcessId);
			DbgPrint("Pmpfun: process Id %u\n", ProcessId);

			BaseAddr = PmpGetProcessByProcessId(ProcessId);

			if (BaseAddr == NULL) {
				DbgPrint("Pmpfun: failed to find PID\n");

				Irp->IoStatus.Status = STATUS_INVALID_BUFFER_SIZE;

				break;
			}

			ProtectPid = (PULONG)((ULONG) BaseAddr + PROTECTED_PROCESS);
			
			DbgPrint("Pmpfun: previous 0x%x\n", *ProtectPid);

			// Set the bit to 0
			*ProtectPid = *ProtectPid & ~PROTECTED_BIT;
			
			DbgPrint("Pmpfun: after    0x%x\n", *ProtectPid);

			break;
		case IOCTL_GET_PMP:
			DbgPrint("Pmpfun: InBufLen = %u, InBuf = 0x%x\n", InputBufferLength, InputBuffer);

			if ((InputBufferLength < sizeof(ULONG)) || (InputBuffer == NULL)) {
				DbgPrint("Pmpfun: bad InputBuffer\n");
				
				Irp->IoStatus.Status = STATUS_INVALID_BUFFER_SIZE;

				break;
			}
			
			DbgPrint("Pmpfun: InputBuffer = 0x%x\n", InputBuffer);

			if ((OutputBufferLength < sizeof(ULONG)) || (OutputBuffer == NULL)) {
				DbgPrint("Pmpfun: bad OutputBuffer\n");

				Irp->IoStatus.Status = STATUS_INVALID_BUFFER_SIZE;

				break;
			}
			
			ProcessId = *((PULONG) InputBuffer);

			DbgPrint("Pmpfun: process Id %u\n", ProcessId);

			BaseAddr = PmpGetProcessByProcessId(ProcessId);
			
			if (BaseAddr == NULL) {
				DbgPrint("Pmpfun: failed to find PID\n");

				Irp->IoStatus.Status = STATUS_INVALID_BUFFER_SIZE;

				break;
			}
			
			ProtectPid = (PULONG)((ULONG) BaseAddr + PROTECTED_PROCESS);
			
			*((PULONG) OutputBuffer) = (ULONG) *ProtectPid;

			DbgPrint("Pmpfun: returned 0x%x\n", *ProtectPid);
			break;
		default:
			DbgPrint("Pmpfun: invalid IOCTL\n");

			Irp->IoStatus.Status = STATUS_INVALID_BUFFER_SIZE;

			break;
	}

	IoCompleteRequest(Irp, IO_NO_INCREMENT);

	return Irp->IoStatus.Status;
}

/*
 * PmpGetProcessByProcessId - locate a process object by its process ID.
 *
 * Starts at the current process (PsGetCurrentProcess) and follows the Flink
 * chain of the LIST_ENTRY that sits PROCESS_LIST bytes into each process
 * object. For each entry it reads the LONG at offset PROCESS_ID and compares
 * it with Pid.
 *
 * Pid: the process ID to search for.
 * Returns the base address of the matching process object, or NULL if there
 * is no current process or the walk returns to the starting entry with no
 * match.
 *
 * NOTE(review): PROCESS_LIST, PROCESS_ID and IMAGE_NAME are hard-coded byte
 * offsets from pmpfun.h. They are presumably valid for a single OS build only;
 * confirm before use.
 * NOTE(review): the list is walked without holding any lock, and the returned
 * object is not referenced, so it can exit while the caller still uses it.
 * NOTE(review): the (ULONG) pointer casts truncate addresses on 64-bit builds.
 * NOTE(review): the starting (current) process is never compared. The loop
 * begins at HeadList->Flink and stops on reaching HeadList, so the caller's
 * own PID is never found.
 */
PULONG
PmpGetProcessByProcessId(LONG Pid)
{
	PULONG BaseAddr = (PULONG) PsGetCurrentProcess();
	
	struct LIST_ENTRY *HeadList;
	struct LIST_ENTRY *Current;

	if (BaseAddr == NULL)
		return NULL;

	/* List head embedded in the current process object. */
	HeadList = (struct LIST_ENTRY *)((ULONG) BaseAddr + PROCESS_LIST);
	Current = HeadList->Flink;

	DbgPrint("Pmpfun: TARGET PID [%i]\n", Pid);

	while (Current != HeadList) {
		/* Convert the embedded LIST_ENTRY back to its containing object. */
		PULONG Base = (PULONG)((ULONG) Current - PROCESS_LIST);
		LONG UniquePid = *((PLONG)((ULONG) Base + PROCESS_ID));
		
		/* Assumes the bytes at IMAGE_NAME are NUL-terminated; confirm. */
		DbgPrint("Pmpfun: [%i] %s\n", UniquePid, (ULONG) Base + IMAGE_NAME);

		if (UniquePid == Pid) 
			return (PULONG) Base;

		Current = Current->Flink;
	}

	return NULL;
}

/*
 * PmpStubIrp - catch-all dispatch routine installed by DriverEntry.
 *
 * Completes the IRP immediately with STATUS_SUCCESS and zero bytes
 * transferred.
 *
 * DeviceObject: unused.
 * Irp:          the request to complete.
 */
NTSTATUS
PmpStubIrp(IN PDEVICE_OBJECT DeviceObject, IN PIRP Irp)
{
	UNREFERENCED_PARAMETER(DeviceObject);

	Irp->IoStatus.Status = STATUS_SUCCESS;
	Irp->IoStatus.Information = 0;
	IoCompleteRequest(Irp, IO_NO_INCREMENT);

	/* The IRP must not be touched after IoCompleteRequest. */
	return STATUS_SUCCESS;
}

/*
 * DriverEntry - driver initialisation.
 *
 * Installs the dispatch routines: PmpStubIrp for every major function,
 * PmpControl for IRP_MJ_DEVICE_CONTROL, and DriverUnload as the unload
 * routine. It then creates the exclusive control device DeviceName, stored in
 * PmpDevice, and the symbolic link DeviceLink that points to it.
 *
 * DriverObject: this driver's object.
 * RegistryPath: unused.
 * Returns STATUS_SUCCESS, or the failing status from IoCreateDevice or
 * IoCreateSymbolicLink. On failure nothing created here is left behind.
 */
NTSTATUS 
DriverEntry(IN PDRIVER_OBJECT DriverObject, IN PUNICODE_STRING RegistryPath)
{
	ULONG i;
	NTSTATUS NtStatus;
	UNICODE_STRING DeviceNameString;
	UNICODE_STRING DeviceLinkString;

	UNREFERENCED_PARAMETER(RegistryPath);

	/*
	 * Install dispatch routines before the device becomes reachable.
	 * IRP_MJ_MAXIMUM_FUNCTION is itself a valid index, hence <=.
	 */
	for (i = 0; i <= IRP_MJ_MAXIMUM_FUNCTION; i++) {
		DriverObject->MajorFunction[i] = PmpStubIrp;
	}

	DriverObject->MajorFunction[IRP_MJ_DEVICE_CONTROL] = PmpControl;
	DriverObject->DriverUnload = DriverUnload;

	RtlInitUnicodeString(&DeviceNameString, DeviceName);
	RtlInitUnicodeString(&DeviceLinkString, DeviceLink);

	/*
	 * 0x00002a7b: custom device type, presumably matching the CTL_CODE device
	 * type used in pmpfun.h (confirm). FILE_DEVICE_SECURE_OPEN makes the
	 * device's security descriptor apply to opens of names below the device.
	 */
	NtStatus = IoCreateDevice(DriverObject, 0, &DeviceNameString,
				  0x00002a7b, FILE_DEVICE_SECURE_OPEN, TRUE, &PmpDevice);

	if (!NT_SUCCESS(NtStatus)) {
		DbgPrint("Pmpfun: failed to create device\n");

		return NtStatus;
	}

	NtStatus = IoCreateSymbolicLink(&DeviceLinkString, &DeviceNameString);

	if (!NT_SUCCESS(NtStatus)) {
		DbgPrint("Pmpfun: failed to create symbolic link\n");

		IoDeleteDevice(PmpDevice);
		PmpDevice = NULL;

		return NtStatus;
	}

	DbgPrint("Pmpfun: loaded\n");

	return STATUS_SUCCESS;
}
