/* optimized matrix multiply */
/* Thomas Zetty, March 13, 2004 */

/* 
ASSUMPTIONS:
1. matrices can be composed of many numeric types. int, long, float and double
   were tested. see comments for a how-to guide
2. only limitation of matrix size is LONG_LINE_SIZE when entering row data

NOTES:
1. the basic multiplication algorithm achieves some speed from requiring
   the second multiplicand to be column-ordered (the first, and the product,
   are in typical row-order). Converting a matrix from row-to-column order is
   a trivial transformation of O(n^2) complexity
2. additional speed is gained by nesting Duff's device (cited below) to
   unroll loops. This greatly depends on the size of the matrices as well
   as the size of the Duffs used, as both compete for cache space.
   TEST several datasets before choosing!
3. average time for 500 million int-type matrix 4x2 x 2x3 multiplies
   on a 1.3GHz P3 was 14 seconds. This time does not include function-call
   overhead
4. the algorithm is still O(n^3), no attempt was made to optimize the
   quantity of multiplies/adds
5. the ?: operator is sparsely used, hopefully without loss of clarity
6. no games were played with function calling procedure or compiler flags

SAMPLE RUN:
Entering Matrix A
Enter number of rows: 4
Enter number of columns: 2
Enter 2 elements for Row 0 :1 0
Enter 2 elements for Row 1 :-2 3
Enter 2 elements for Row 2 :5 4
Enter 2 elements for Row 3 :0 1
Entering Matrix B
Enter number of rows: 2
Enter number of columns: 3
Enter 3 elements for Row 0 :0 6 1
Enter 3 elements for Row 1 :3 8 -2
Matrix A
[  1  0 ]
[ -2  3 ]
[  5  4 ]
[  0  1 ]

Matrix B
[  0  6  1 ]
[  3  8 -2 ]

Product Matrix AB
[  0  6  1 ]
[  9 12 -8 ]
[ 12 62 -3 ]
[  3  8 -2 ]

*/

#include <stdio.h> 
#include <stdlib.h> 

/* LONG_LINE_SIZE: maximum number of bytes fgets may store per input */
/* line in matrix_in (its buffer holds LONG_LINE_SIZE+1 chars) */
#define LONG_LINE_SIZE 8000 
/* SHORT_LINE_SIZE: size of the buffers that hold the generated */
/* scanf/printf format strings in matrix_in and matrix_out */
#define SHORT_LINE_SIZE 80 

/* storage order of the elements inside the data buffer of a MAT_T */
typedef enum
{
	/* element (row y, column x) lives at data[y*cols + x] */
	ROW_ORDER = 0,
	/* element (row y, column x) lives at data[y + rows*x] */
	COL_ORDER
} MAT_ORDER_T;

/* element type stored in every matrix; when changing it, adjust the */
/* four ELEMENT_* macros below to match */
typedef int ELEMENT_T;
/* ELEMENT_SCAN_STRING and ELEMENT_PRINT_STRING are used */
/* in scanf and printf and should be appropriate for ELEMENT_T */
/* ELEMENT_PRINT_STRING should not have a leading %{width}: */
/* matrix_out prepends "%" and the column width itself */
/* ELEMENT_MIN_WIDTH and ELEMENT_WIDTH_ADD should reflect */
/* ELEMENT_PRINT_STRING; for example: */
/* if ELEMENT_PRINT_STRING is .2f */
/* ELEMENT_MIN_WIDTH should be at least 4 and ELEMENT_WIDTH_ADD 3 */
/* (this would be a bit more straightforward with C++ iostreams =) */
#define ELEMENT_SCAN_STRING "%d" 
#define ELEMENT_PRINT_STRING "d" 
#define ELEMENT_MIN_WIDTH 1 
#define ELEMENT_WIDTH_ADD 0 

/* reasonable choices for float ELEMENT_T; for double ELEMENT_T the */
/* scan string must be "%lf", because scanf-style "%f" stores a float */
/*
#define ELEMENT_SCAN_STRING "%f"
#define ELEMENT_PRINT_STRING ".1f"
#define ELEMENT_MIN_WIDTH 3
#define ELEMENT_WIDTH_ADD 2
*/

/* reasonable choices for int ELEMENT_T */
/*
#define ELEMENT_SCAN_STRING "%d"
#define ELEMENT_PRINT_STRING "d"
#define ELEMENT_MIN_WIDTH 1
#define ELEMENT_WIDTH_ADD 0
*/

/* a dense matrix of ELEMENT_T values held in one contiguous buffer */
typedef struct {
	unsigned rows;		/* number of rows */
	unsigned cols;		/* number of columns */
	MAT_ORDER_T order;	/* layout of data: ROW_ORDER or COL_ORDER */
	const char * name;	/* label printed by matrix_out; may be NULL; */
				/* never freed by matrix_free */
	ELEMENT_T * data;	/* rows*cols elements obtained by matrix_alloc, */
				/* or NULL when no buffer is attached */
} MAT_T;

/* put m into its empty state: 0 x 0, row-ordered, no name and no */
/* data buffer. A buffer that is still attached is NOT freed here. */
void matrix_init(MAT_T * m)
{
	m->data = NULL;
	m->name = NULL;
	m->order = ROW_ORDER;
	m->rows = m->cols = 0;
}

/* attach an element buffer to m, sized from m->rows and m->cols. */
/* If m->data is already set it is kept and assumed to be large */
/* enough. If the byte count would not fit in a size_t, or the */
/* allocation fails, the error is reported on stderr and the */
/* program terminates with EXIT_FAILURE. */
void matrix_alloc(MAT_T * m)
{
	/* largest element count whose byte size still fits in a size_t */
	const size_t max_elements = (size_t)-1 / sizeof(ELEMENT_T);

	/* if there is any data space allocated, assume it is large enough */
	if (m->data != NULL) return;

	/* only allocate when rows*cols*sizeof(ELEMENT_T) cannot wrap; */
	/* otherwise data stays NULL and is reported below */
	if (m->cols == 0 || m->rows <= max_elements / m->cols) {
		m->data = malloc(sizeof(ELEMENT_T) * m->rows * m->cols);
	}

	if (m->data == NULL) {
		fprintf(stderr, "Memory Allocation Error!\n");
		exit(EXIT_FAILURE);
	}
}

/* release the element buffer of m, if any, and mark m as having */
/* none; dimensions, order and name are left untouched */
void matrix_free(MAT_T * m)
{
	/* free(NULL) is a no-op, so no NULL test is required */
	free(m->data);
	m->data = NULL;
}

/* Duff's device is documented at: */
/* http://www.lysator.liu.se/c/duffs-device.html */

/* ZDUFF, YDUFF and XDUFF are nested Duff's devices: each switch jumps */
/* into the middle of an unrolled do-while so that the first pass */
/* handles count % N steps and every later pass handles N steps (the */
/* case labels fall through on purpose). */
/* The macros are not hygienic. They read and modify these variables */
/* of the expanding scope: a, b, c, ta, tb, x, y, z, ma and mb. */
/* The counts x, y and z must be SIGNED and greater than zero on */
/* entry: a count of zero still runs one full unrolled pass, and an */
/* unsigned count would never satisfy the loop exit test. */
/* The backslash has to be the very last character of each continued */
/* line; trailing blanks after it are not portable. */

/* ZDUFF: inner product unrolled by 8. Adds z products *ta * *tb to */
/* *c and advances ta and tb by z elements. */
#define ZDUFF \
switch (z % 8) {\
	case 0:	do {	*c += *ta++ * *tb++;\
	case 7:			*c += *ta++ * *tb++;\
	case 6:			*c += *ta++ * *tb++;\
	case 5:			*c += *ta++ * *tb++;\
	case 4:			*c += *ta++ * *tb++;\
	case 3:			*c += *ta++ * *tb++;\
	case 2:			*c += *ta++ * *tb++;\
	case 1:			*c += *ta++ * *tb++;\
				} while ((z-=8) > 0);\
}

/* YDUFF: one product row unrolled by 8. For each of the y columns of */
/* mb it restarts ta at the current row a, clears *c, runs ZDUFF over */
/* ma->cols elements and steps c to the next product element. tb is */
/* not reset here, so it walks through the columns of mb back to back. */
#define YDUFF \
switch (y % 8) {\
	case 0:	do {	ta=a; *c=0; z=ma->cols; ZDUFF; ++c;\
	case 7:			ta=a; *c=0; z=ma->cols; ZDUFF; ++c;\
	case 6:			ta=a; *c=0; z=ma->cols; ZDUFF; ++c;\
	case 5:			ta=a; *c=0; z=ma->cols; ZDUFF; ++c;\
	case 4:			ta=a; *c=0; z=ma->cols; ZDUFF; ++c;\
	case 3:			ta=a; *c=0; z=ma->cols; ZDUFF; ++c;\
	case 2:			ta=a; *c=0; z=ma->cols; ZDUFF; ++c;\
	case 1:			ta=a; *c=0; z=ma->cols; ZDUFF; ++c;\
				} while ((y-=8) > 0);\
}

/* XDUFF: all product rows unrolled by 4. For each of the x rows it */
/* rewinds tb to the start of b, runs YDUFF over mb->cols columns and */
/* advances a by ma->cols to the next row of the first operand. */
#define XDUFF \
switch (x % 4) {\
	case 0:	do {	tb=b; y=mb->cols; YDUFF; a+=ma->cols;\
	case 3:			tb=b; y=mb->cols; YDUFF; a+=ma->cols;\
	case 2:			tb=b; y=mb->cols; YDUFF; a+=ma->cols;\
	case 1:			tb=b; y=mb->cols; YDUFF; a+=ma->cols;\
				} while ((x-=4) > 0);\
}

/* compute the matrix product mc = ma * mb. */
/* assumes multiplication is defined on ma and mb (ma->cols equals */
/* mb->rows; not checked here), that ma is row-ordered and that mb is */
/* column-ordered. mc becomes a row-ordered ma->rows x mb->cols */
/* matrix. Its buffer comes from matrix_alloc, so a buffer already */
/* attached to mc is reused and must be large enough. */
void matrix_multiply(const MAT_T * ma, const MAT_T * mb,
					 MAT_T * mc)
{
	const ELEMENT_T * a = ma->data, * ta;
	const ELEMENT_T * b = mb->data, * tb;
	ELEMENT_T * c;
	int x, y, z;

	mc->rows = ma->rows;
	mc->cols = mb->cols;
	/* the loops below emit the product one row after another */
	mc->order = ROW_ORDER;

	/* a Duff's device runs one full unrolled pass even for a count of */
	/* zero, so degenerate shapes must never reach XDUFF */
	if (ma->rows == 0 || mb->cols == 0) return;

	matrix_alloc(mc);
	c = mc->data;

	if (ma->cols == 0) {
		/* empty inner dimension: every element is an empty sum */
		size_t left = (size_t)mc->rows * mc->cols;
		while (left-- > 0) *c++ = 0;
		return;
	}

	x = ma->rows;
	XDUFF;
}

/*
void old_multiply(const MAT_T * ma, const MAT_T * mb,
						MAT_T * mc)
{
	const ELEMENT_T * a = ma->data, * ta;
	const ELEMENT_T * b = mb->data, * tb;
	ELEMENT_T * c;
	unsigned x, y, z;

	mc->rows = ma->rows;
	mc->cols = mb->cols;
	matrix_alloc(mc);

	c = mc->data;
	for (x=0; x<ma->rows; ++x, a+=ma->cols) {
		for (y=0, tb=b; y<mb->cols; ++y, ++c) {
			for (z=0, *c=0, ta=a; z<ma->cols; ++z) {
				*c += *ta++ * *tb++;
			}
		}
	}
}
*/

/* read one line from stdin into buf (at most size-1 characters); */
/* end of input or a read error terminates the program */
static void matrix_in_line(char * buf, int size)
{
	if (fgets(buf, size, stdin) == NULL) {
		fprintf(stderr, "Unexpected end of input!\n");
		exit(EXIT_FAILURE);
	}
}

/* print prompt, read one line and parse an unsigned number from it; */
/* returns 0 when the line does not start with a number */
static unsigned matrix_in_dimension(const char * prompt, char * buf, int size)
{
	unsigned value = 0;

	printf("%s", prompt);
	matrix_in_line(buf, size);
	if (sscanf(buf, "%u", &value) != 1) value = 0;
	return value;
}

/* interactively read a matrix from stdin into m. */
/* m is reinitialized first (an old buffer is not freed), then named */
/* `name` and stored in the given `order` (a MAT_ORDER_T value). */
/* The user enters the row count, the column count and then one line */
/* of whitespace-separated elements per row. Any invalid or missing */
/* input is reported on stderr and terminates the program with */
/* EXIT_FAILURE. `name` is stored by pointer, not copied. */
void matrix_in(MAT_T * m, const char * name, unsigned order)
{
	unsigned row, col;
	char buf[LONG_LINE_SIZE+1];
	char format[SHORT_LINE_SIZE];
	const char * cursor;
	ELEMENT_T * slot;
	int consumed;

	/* element conversion followed by %n, which reports how many */
	/* characters of the line that conversion consumed */
	sprintf(format, "%s%%n", ELEMENT_SCAN_STRING);

	matrix_init(m);
	m->name = name;
	m->order = (MAT_ORDER_T)order;

	printf("Entering %s\n", m->name);

	m->rows = matrix_in_dimension("Enter number of rows: ",
					buf, LONG_LINE_SIZE);
	if (m->rows < 1) {
		fprintf(stderr, "Invalid number of rows!");
		exit(EXIT_FAILURE);
	}
	m->cols = matrix_in_dimension("Enter number of columns: ",
					buf, LONG_LINE_SIZE);
	if (m->cols < 1) {
		fprintf(stderr, "Invalid number of columns!");
		exit(EXIT_FAILURE);
	}

	matrix_alloc(m);

	for (row=0; row<m->rows; ++row) {
		printf("Enter %u elements for Row %u :", m->cols, row);
		matrix_in_line(buf, LONG_LINE_SIZE);

		/* scan buf, extracting one element each pass */
		cursor = buf;
		for (col=0; col<m->cols; ++col) {
			slot = &m->data[ (m->order==ROW_ORDER) ?
					(row*m->cols+col) : (row+m->rows*col)];
			/* note %n does not count as an extracted element */
			if (sscanf(cursor, format, slot, &consumed) != 1) {
				fprintf(stderr, "Invalid row data!");
				exit(EXIT_FAILURE);
			}
			cursor += consumed;
		}
	}
}

/* number of characters needed to print x with ELEMENT_PRINT_STRING: */
/* an optional minus sign, the digits left of the decimal point and */
/* ELEMENT_WIDTH_ADD further characters, but never fewer than */
/* ELEMENT_MIN_WIDTH. matrix_out uses the result to align columns. */
unsigned width(ELEMENT_T x)
{
	unsigned w = ELEMENT_WIDTH_ADD;

	if (x < 0) ++w;

	/* count the digits left of the decimal point; at least one is */
	/* always printed. We need to be careful to check a range around */
	/* zero (rather than x != 0) in case ELEMENT_T is a floating-point */
	/* type, and the range includes +1 and -1 so that the leading */
	/* digit of values such as 10 or 100 is counted as well. x is */
	/* never negated, so the most negative integer cannot overflow; */
	/* for integer types this relies on division truncating toward */
	/* zero (guaranteed since C99). */
	do {
		x /= 10;
		++w;
	} while ((x >= 1) || (x <= -1));

	if (w < ELEMENT_MIN_WIDTH) w = ELEMENT_MIN_WIDTH;
	return w;
}

/* print m to stdout: its name (when not NULL) on a line of its own, */
/* then one "[ ... ]" line per row and a final blank line. All */
/* elements are right-aligned to the width of the widest one, and */
/* the data buffer is indexed according to m->order. */
void matrix_out(const MAT_T * m)
{
	char format[SHORT_LINE_SIZE];
	unsigned x, y, z, w, maxwidth=1;
	const unsigned count = m->rows * m->cols;

	/* the widest element decides the common column width */
	for (z=0; z<count; ++z) {
		w = width(m->data[z]);
		if (w > maxwidth) maxwidth = w;
	}
	/* builds e.g. "%3d " ; maxwidth is unsigned, hence %u */
	sprintf(format, "%%%u%s ", maxwidth, ELEMENT_PRINT_STRING);

	if (m->name != NULL) printf("%s\n", m->name);
	for (y=0; y<m->rows; ++y) {
		printf("[ ");
		for (x=0; x<m->cols; ++x) {
			printf(format,
				m->data[ (m->order==ROW_ORDER) ?
						(y*m->cols+x) : (y+m->rows*x)]
				);
		}
		printf("]\n");
	}
	printf("\n");
}

/* read matrix A (row-ordered) and matrix B (column-ordered) from */
/* stdin, print both, then print their product AB. */
/* Returns EXIT_SUCCESS, or EXIT_FAILURE when the shapes do not allow */
/* the multiplication. */
int main(void)
{
	MAT_T a, b, c;

	/* the multiply routine requires the second operand column-ordered */
	matrix_in(&a, "Matrix A", ROW_ORDER);
	matrix_in(&b, "Matrix B", COL_ORDER);
	if (a.cols != b.rows) {
		fprintf(stderr,
			"Invalid Matrix sizes. Multiplication is not defined!\n");
		matrix_free(&a);
		matrix_free(&b);
		return EXIT_FAILURE;
	}
	matrix_init(&c);
	c.name = "Product Matrix AB";

	matrix_out(&a);
	matrix_out(&b);
	matrix_multiply(&a, &b, &c);
	matrix_out(&c);

	matrix_free(&a);
	matrix_free(&b);
	matrix_free(&c);

	return EXIT_SUCCESS;
}