#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <errno.h>

#include "get_address.h"
#include "xex.h"

#define SELF "xexamine"

/* #define CLASSIFY_DEBUG */

/*
	show all segments of a xex file, including:
	has ffff header or not
	start address, end address, length
	whether the segment contains code or just data (heuristics)
	checksum of the segment contents
*/

/* crc32() and crc32_for_byte() come from public domain code:
	http://home.thep.lu.se/~bjorn/crc/
*/

/* Compute one entry of the lookup table used by crc32().
   r is pushed through eight shift-and-xor rounds with the constant
   0xEDB88320 (xored in whenever the low bit of r is clear), then the
   top byte of the result is inverted. */
uint32_t crc32_for_byte(uint32_t r) {
	int rounds_left = 8;

	while(rounds_left-- > 0) {
		uint32_t mix = (r & 1) ? 0 : (uint32_t)0xEDB88320L;
		r = mix ^ (r >> 1);
	}

	return r ^ (uint32_t)0xFF000000L;
}

/* Fold n_bytes bytes starting at data into the running checksum *crc.
   The caller owns *crc and chooses its starting value; calling again
   with the same *crc continues the checksum over more data.
   The 256-entry table is built by crc32_for_byte() on the first call
   (detected by table[0] still being zero). */
void crc32(const void *data, size_t n_bytes, uint32_t* crc) {
	static uint32_t table[0x100];
	const uint8_t *bytes = (const uint8_t *)data;
	size_t pos;

	if(table[0] == 0) {
		for(pos = 0; pos < 0x100; ++pos)
			table[pos] = crc32_for_byte(pos);
	}

	for(pos = 0; pos < n_bytes; ++pos) {
		uint32_t cur = *crc;
		/* low byte of the checksum xor next input byte picks the entry */
		*crc = table[(uint8_t)cur ^ bytes[pos]] ^ (cur >> 8);
	}
}

void usage(int status) {
	printf("Usage: " SELF " [-h] [-d] [-v] [-s segment] file.xex\n");
	exit(status);
}

#define CL_DATA 0
#define CL_OPCODE 1
#define CL_OPERAND 2

/* 3 tables used by classify_seg() */

/* opcode_valid[b] is nonzero when classify_seg() should treat byte
   value b as the first byte of an instruction, and zero when it should
   treat it as data. The index is the byte itself: the label at the end
   of each row is the high nybble, the column labels are the low nybble. */
int opcode_valid[] = {
/* 0  1  2  3  4  5  6  7  8  9  a  b  c  d  e  f */
	1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, /* 0 */
	1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, /* 1 */
	1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, /* 2 */
	1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, /* 3 */
	1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, /* 4 */
	1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, /* 5 */
	1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, /* 6 */
	1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, /* 7 */
	0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, /* 8 */
	1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, /* 9 */
	1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, /* a */
	1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, /* b */
	1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, /* c */
	1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, /* d */
	1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, /* e */
	1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, /* f */
};

/* opcode_lengths[b] is how many bytes classify_seg() steps over when
   it meets byte value b at an instruction boundary: 1 for the byte
   itself plus 0, 1 or 2 operand bytes. Every byte that opcode_valid[]
   rejects has length 1 here, so it is skipped on its own.
   Indexed like opcode_valid[]: row = high nybble, column = low nybble. */
int opcode_lengths[] = {
/* 0  1  2  3  4  5  6  7  8  9  a  b  c  d  e  f */
	1, 2, 1, 1, 1, 2, 2, 1, 1, 2, 1, 1, 1, 3, 3, 1, /* 0 */
	2, 2, 1, 1, 1, 2, 2, 1, 1, 3, 1, 1, 1, 3, 3, 1, /* 1 */
	3, 2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 3, 3, 3, 1, /* 2 */
	2, 2, 1, 1, 1, 2, 2, 1, 1, 3, 1, 1, 1, 3, 3, 1, /* 3 */
	1, 2, 1, 1, 1, 2, 2, 1, 1, 2, 1, 1, 3, 3, 3, 1, /* 4 */
	2, 2, 1, 1, 1, 2, 2, 1, 1, 3, 1, 1, 1, 3, 3, 1, /* 5 */
	1, 2, 1, 1, 1, 2, 2, 1, 1, 2, 1, 1, 3, 3, 3, 1, /* 6 */
	2, 2, 1, 1, 1, 2, 2, 1, 1, 3, 1, 1, 1, 3, 3, 1, /* 7 */
	1, 2, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 3, 3, 3, 1, /* 8 */
	2, 2, 1, 1, 2, 2, 2, 1, 1, 3, 1, 1, 1, 3, 1, 1, /* 9 */
	2, 2, 2, 1, 2, 2, 2, 1, 1, 2, 1, 1, 3, 3, 3, 1, /* a */
	2, 2, 1, 1, 2, 2, 2, 1, 1, 3, 1, 1, 3, 3, 3, 1, /* b */
	2, 2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 3, 3, 3, 1, /* c */
	2, 2, 1, 1, 1, 2, 2, 1, 1, 3, 1, 1, 1, 3, 3, 1, /* d */
	2, 2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 3, 3, 3, 1, /* e */
	2, 2, 1, 1, 1, 2, 2, 1, 1, 3, 1, 1, 1, 3, 3, 1, /* f */
};

/* opcode_is_ctlxfr[b] is 1 when byte value b is a control transfer
   opcode: the conditional branches ($10, $30, $50, $70, $90, $b0, $d0,
   $f0), JMP abs ($4c), JMP (ind) ($6c), RTS ($60) and RTI ($40).
   JSR ($20) deliberately doesn't count!
   classify_seg() uses this to decide where a stretch of code may
   legitimately stop before data begins.
   Indexed like opcode_valid[]: row = high nybble, column = low nybble. */
int opcode_is_ctlxfr[] = {
/* 0  1  2  3  4  5  6  7  8  9  a  b  c  d  e  f */
	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 0 */
	1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 1 */
	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 2 */
	1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 3 */
	1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, /* 4 */
	1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 5 */
	1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, /* 6 */
	1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 7 */
	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 8 */
	1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* 9 */
	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* a */
	1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* b */
	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* c */
	1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* d */
	0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* e */
	1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* f */
};

/*
	find_dlist() looks for a byte pattern that resembles a display list
	and marks it as data.

	mem: the bytes to scan (len of them).
	map: per-byte classification array, same indexing as mem; entries
	     covering the match are overwritten with CL_DATA.
	len: number of bytes in mem. Offsets 2 .. len-3 are scanned.

	A match starts at three consecutive 'p' (0x70) bytes and ends at the
	first 'A' (0x41) byte seen afterwards. If several "ppp" groups occur
	before the 'A', the most recent one is the start. Presumably these
	are the usual opening and closing instructions of an Atari display
	list - confirm against the hardware docs. With no match, map is
	left untouched.
*/
void find_dlist(const unsigned char *mem, unsigned char *map, int len) {
	int pos;
	int first = -1;	/* offset of the latest "ppp", -1 = none yet */
	int last = -1;	/* offset of the terminating 'A', -1 = none */

	for(pos = 2; pos < len - 2; pos++) {
		if(mem[pos - 2] == 'p' && mem[pos - 1] == 'p' && mem[pos] == 'p')
			first = pos - 2;

		if(first >= 0 && mem[pos] == 'A') {
			last = pos;
			break;
		}
	}

	if(first < 0 || last < 0)
		return;

#ifdef CLASSIFY_DEBUG
	printf("display list found, offsets %d - %d\n", first, last);
#endif

	/* both ends inclusive */
	memset(map + first, CL_DATA, (size_t)(last - first + 1));
}

/*
	classify_seg() returns an estimate of how much of a segment is
	6502 code, as a percentage.

	seg: the segment to examine. seg.object[0 .. seg.len-1] are its
	     bytes; seg.object[0] is taken to live at address
	     seg.start_addr.
	returns: percentage (0-100, rounded down) of the bytes that are
	     still classified as opcode or operand after all the passes
	     below. An empty segment returns 0.

	passes:
	1.   semi-disassemble from the first byte using opcode_valid[] and
	     opcode_lengths[]: classify each byte as opcode, operand or
	     data.
	1.5  whatever find_dlist() recognizes as a display list is data.
	2.   runs of one repeated byte value are data (thresholds are in
	     the code), including a run that reaches the end of the segment.
	3.   repeated until nothing changes:
	     a. code that reaches data without first hitting a control
	        transfer (branch, JMP, RTS, RTI; see opcode_is_ctlxfr[]) is
	        data, back to the byte after the previous control transfer.
	        The control transfer instruction itself is kept.
	     b. JMP/JSR absolute whose target is inside the segment but is
	        not an opcode, is data. JMP/JSR out of the segment can't be
	        judged and is left alone. A branch whose target is outside
	        the segment, or is not an opcode, is data.
*/

int classify_seg(const xex_segment seg) {
	const unsigned char *obj = seg.object;
	int len = seg.len;
	int base = (int)seg.start_addr;
	int i, j, byte, oplen, target, last_cltxfr, changed;
	int runstart = 0, runbyte = -1, runcount = 0;
	int passes = 0, bytes_changed = 0;
	unsigned char map[65536];

	/* nothing to classify; also keeps the final division safe */
	if(len <= 0) return 0;

	/* map[] has one entry per segment byte, never index past it */
	if(len > (int)sizeof(map)) len = (int)sizeof(map);

	memset(map, CL_DATA, sizeof(map));

	/* pass 1: mark valid opcodes as CL_OPCODE, their operands as
		CL_OPERAND, and anything else as CL_DATA. operand bytes that
		would fall beyond the end of the segment are not touched. */
	for(i = 0; i < len; ) {
		byte = obj[i];
		oplen = opcode_lengths[byte];
		if(opcode_valid[byte]) {
			map[i] = CL_OPCODE;
			for(j = 1; j < oplen && i + j < len; j++)
				map[i + j] = CL_OPERAND;
		} else {
			map[i] = CL_DATA;
		}
		i += oplen;
	}

	/* pass 1.5: if there's a display list, it's data */
	find_dlist(obj, map, len);

	/* pass 2: runs of the same byte value are data, unless they're
		short runs of ASL A, LSR A, or NOP.
		runcount counts the repeats *after* the first byte of a run, so
		the tests below mean "10 or more identical bytes" and "4 or more
		identical bytes that aren't $0a/$4a/$ea".
		NOTE(review): older comments said ">=3" and ">=8"; the
		thresholds themselves are unchanged here.
		the loop runs one step past the last byte with a value no byte
		can have (-2), so a run that ends the segment is handled too. */
	for(i = 0; i <= len; i++) {
		byte = (i < len) ? obj[i] : -2;
		if(byte == runbyte) {
			runcount++;
			continue;
		}

		if(
				runcount > 8 ||
				(runcount >= 3 && !(runbyte == 0x0a || runbyte == 0x4a || runbyte == 0xea))
		  )
		{
			#ifdef CLASSIFY_DEBUG
			printf("run of %d bytes, $%02x, at %d\n", runcount, runbyte, runstart);
			#endif
			for(j = runstart; j < i; j++)
				map[j] = CL_DATA;
		}
		runstart = i;
		runbyte = byte;
		runcount = 0;
	}

	/* pass 3: the two sub-passes feed each other, so repeat them until
		the second one stops changing anything. */
	do {
		passes++;
		changed = 0;

		/* pass 3a: code that doesn't branch/jump/return and runs into
			data gets marked as data. last_cltxfr is the offset of the
			first byte that could belong to such a stretch: the start of
			the segment, or the byte after the latest control transfer
			instruction or data byte. it restarts at 0 every time so each
			repetition sees the segment the same way. */
		last_cltxfr = 0;
		for(i = 0; i < len; i++) {
			if(map[i] == CL_OPCODE && opcode_is_ctlxfr[obj[i]]) {
				/* the transfer itself is fine, start after it */
				last_cltxfr = i + opcode_lengths[obj[i]];
			} else if(map[i] == CL_DATA) {
				for(j = last_cltxfr; j < i; j++)
					map[j] = CL_DATA;
				last_cltxfr = i + 1;
			}
		}

		/* pass 3b: branch, JMP abs and JSR abs instructions whose
			target is data, are also data. target is an offset into the
			segment (an index into map[]), not an address. */
		for(i = 0; i < len; i++) {
			if(map[i] != CL_OPCODE) continue;

			byte = obj[i];
			oplen = opcode_lengths[byte];

			/* operand cut off by the end of the segment: can't tell */
			if(i + oplen > len) continue;

			switch(byte) {
				case 0x4c: /* JMP absolute */
				case 0x20: /* JSR absolute */
					/* the operand is an absolute address, little-endian */
					target = (obj[i + 1] | (obj[i + 2] << 8)) - base;
					if(target < 0 || target >= len)
						break; /* jsr/jmp out of segment */
					if(map[target] != CL_OPCODE) {
						map[i] = map[i + 1] = map[i + 2] = CL_DATA;
						changed = 1;
						bytes_changed += 3;
					}
					break;

				case 0x10: /* BPL */
				case 0x30: /* BMI */
				case 0x50: /* BVC */
				case 0x70: /* BVS */
				case 0x90: /* BCC */
				case 0xb0: /* BCS */
				case 0xd0: /* BNE */
				case 0xf0: /* BEQ */
					/* the operand is a signed distance from the byte
						after the 2-byte branch instruction */
					target = i + 2 + ((signed char)obj[i + 1]);
					if(target < 0 || target >= len || map[target] != CL_OPCODE) {
						/* branch out of segment (assume data), or
							branch to data! */
						map[i] = map[i + 1] = CL_DATA;
						changed = 1;
						bytes_changed += 2;
					}
					break;

				default:
					break;
			}
		}
	} while(changed);
#ifdef CLASSIFY_DEBUG
	printf("pass 3 ran %d times, changed %d bytes\n", passes, bytes_changed);
#else
	/* only reported in debug builds */
	(void)passes;
	(void)bytes_changed;
#endif

	/* last pass: calculate opcode/operand percentage */
	j = 0;
	for(i = 0; i < len; i++) {
		if(map[i] != CL_DATA) j++;
	}

#ifdef CLASSIFY_DEBUG
	for(i = 0; i < len; i++) {
		if(i % 16 == 0) {
			printf("\n%04x: ", i);
		}
		printf("%02x/%d ", obj[i], map[i]);
	}
	printf("\n");
#endif

	/* integer math: exact, and j * 100 is far below INT_MAX here */
	return (j * 100) / len;
}

/*
	main(): parse the options, then for every file named on the command
	line read it one segment at a time with xex_fread_seg() and print a
	table row per segment: segment number, file offset, start and end
	address, length, CRC32, estimated code percentage, and any run/init
	address.

	options:
	-d          print addresses in decimal rather than hex
	-s segment  only show this segment number (1-based)
	-v          sets xex_verbose
	-h          print usage and exit 0

	returns 0 on success. exits/returns 1 if an option is bad, no file
	was given, a file can't be opened, the segment asked for with -s
	doesn't exist, or a file contains no segments at all.
*/
int main(int argc, char **argv) {
	FILE *f;
	xex_segment seg;
	unsigned char buffer[65536];
	char *filename;
	int opt, offset, segcount = 0, only_segment = 0, header_printed = 0, decimal = 0, print_filenames = 0, r, i;
	uint32_t crc;

	while((opt = getopt(argc, argv, "vhs:d")) != -1) {
		switch(opt) {
			case 'd':
				decimal = 1;
				break;
			case 's':
				if( (only_segment = get_address(SELF, optarg)) < 0 )
					exit(1);
				if(!only_segment) {
					fprintf(stderr, SELF ": 0 is not a valid segment number (they start with 1).\n");
					exit(1);
				}
				break;
			case 'v':
				xex_verbose = 1;
				break;
			case 'h':
				usage(0);
				break;
			default:
				usage(1);
				break;
		}
	}

	if(optind >= argc) {
		fprintf(stderr, SELF ": no xex file argument.\n");
		usage(1);
	}

	/* with more than one file, label each file's rows with its name */
	if(argc > optind + 1) print_filenames = 1;
	while(optind < argc) {
		filename = argv[optind++];
		if( !(f = fopen(filename, "rb")) ) {
			fprintf(stderr, "%s: ", SELF);
			perror(filename);
			exit(1);
		}
		if(print_filenames) printf("%s:\n", filename);

		/* segment contents are read into our buffer */
		seg.object = buffer;

		offset = segcount = 0;
		while(xex_fread_seg(&seg, f)) {
			segcount++;

			if(!only_segment || (only_segment == segcount)) {
				/* the header is printed once per run, not once per file */
				if(!header_printed) {
					printf("Seg | Offset | Start | End   | Bytes | CRC32    | Code%% | Run/Init?\n");
					header_printed++;
				}

				/* only checksum the segments that actually get shown */
				crc = 0;
				crc32(seg.object, seg.len, &crc);

				/* both formats must have all 7 columns, or the rows
					won't line up with the header */
				printf((decimal ?
							"%3d | %6d | %5d | %5d | %5d | %08x | %4d%% | " :
							"%3d | %6d | $%04x | $%04x | %5d | %08x | %4d%% | "),
						segcount, offset, seg.start_addr, seg.end_addr, seg.len, crc,
						classify_seg(seg));

				/* NOTE(review): r and i are only compared against -1; the
					address printed is always the first two bytes of the
					segment (little-endian). if xex_get_run_addr() and
					xex_get_init_addr() return the address itself, printing
					r and i would also be right for a segment that holds
					both - confirm against xex.h. */
				r = xex_get_run_addr(&seg);
				i = xex_get_init_addr(&seg);
				if(r != -1) {
					/* pad so "Run" is as wide as "Init" when it's alone */
					if(i == -1) putchar(' ');
					printf((decimal ? "Run %5d"  : "Run $%04x"), (seg.object[0] | (seg.object[1] << 8)));
				}
				if(i != -1) {
					if(r != -1) printf(", ");
					printf((decimal ? "Init %5d"  : "Init $%04x"), (seg.object[0] | (seg.object[1] << 8)));
				}

				putchar('\n');
			}

			/* where the next segment begins in the file */
			offset = ftell(f);
		}

		/* done reading: don't keep one open FILE per argument */
		fclose(f);

		if(xex_errno) {
			fprintf(stderr, SELF ": %s: %s\n",
					filename, xex_strerror(xex_errno));
		}

		if(only_segment && (segcount < only_segment)) {
			fprintf(stderr, SELF ": can't show segment %d, only %d segments in file.\n", only_segment, segcount);
			return 1;
		}

		if(!segcount) {
			fprintf(stderr, SELF ": %s is not an Atari 8-bit executable.\n", filename);
			return 1;
		}
	}

	return 0;
}