/*
* Generating Rules from Back-Propagation Weights
*
* This program demos how the rules may be extracted for a net with
* several inputs and one output. All the program needs are the weights
* and threshold values. The algorithm is to search the space of
* statements that can be made about the input to the network. For
* example, "i1=1 and i4=0 and i7=1" is a statement of length three
* which defines a class of inputs. The statement space is searched in
* shortest-to-longest and strongest-to-weakest order. The data struct
* representing a statement is an array of integers. The size is also
* passed with the array, and the array is always sorted least to
* greatest.
*/
/* NOTE(review): the original had seven #include lines whose header names were
 * lost (the text between '<' and '>' is missing). The four headers below are
 * the ones the visible code needs; restore any others from the original. */
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
/* Capacity limits for the fixed-size tables below. */
#define maxwets 5000
#define maxth 1000
#define wmax 1000
/* netf is opened for reading from netfn and rulef for writing to rulefn
 * (both opened in load_weights). */
FILE *netf, *rulef;
char netfn[80];
char rulefn[80];
float wet[maxwets]; /* network weights are stored here */
/* indexed by node number (n+firstnode) in extract_rules */
float thresh[maxth];
/* cutoff value read from the command line or the prompt in main; totals are
 * compared against it when deciding whether a rule holds */
float bar;
/* map[] entries are used as indices into wet[] (see print_rule) */
int map[wmax];
/* starting statement handed to explore() by extract_rules (vec[0] = 0) */
int vec[wmax];
float wet2[wmax]; /* a sorted abs copy of wet */
/* wsiz: set to numwet by extract_rules, fw advances by wsiz for each node;
 * limit: explore() only lengthens a statement while its length is < limit */
int wsiz, limit;
int t; /* used by explore */
/* running total compared against bar in extract_rules */
float totmin;
/* layer sizes read from the net file by load_weights */
int insiz,h1siz,h2siz,outsiz;
/* numnodes is read by load_weights; numlayers selects the case in main.
 * NOTE(review): where numwets and numlayers are assigned is not visible in
 * this view of the file. */
int numwets,numlayers,numnodes;
int pass,n,fw; /* used by extract_rules and by print_rule */
/* name prefixes printed for the consequent (rhs) and the antecedents (lhs)
 * of a rule; main sets them to "OUT"/"HID" and "IN"/"HID" */
char rhs[10];
char lhs[10];
long rnum; /* number of rules generated */
/**************************************************************************/
/* `visit` expands to nothing; no use of it is visible in this view. */
#define visit
/*
 * print_rule: statement-like macro that increments rnum and appends one rule
 * to rulef in the form
 *
 *     If [not] <lhs><k> & [not] <lhs><k> ... then [not] <rhs><n>
 *
 * where k is map[vec[i]]-fw for each term of the statement.
 * An antecedent is printed with "not" when (its weight > 0) XOR (pass == 0);
 * the consequent is printed with "not" when pass != 0.
 *
 * The macro relies on these names being in scope where it is expanded:
 * i, vec, map, wet, pass, fw, n, lhs, rhs, rulef, rnum.
 *
 * NOTE(review): the `for` line below is truncated -- the loop bound, the
 * increment, the opening '{' and the start of an `if (i>0)` test appear to
 * have been lost (text between '<' and '>' is missing). It must be restored
 * from the original before this macro will compile.
 */
#define print_rule \
++rnum; \
fprintf(rulef,"If"); \
for (i=0; i0) fprintf(rulef," &"); \
if ((wet[map[vec[i]]]>0) ^ (pass==0)) \
fprintf(rulef," not"); \
fprintf(rulef," %s%d",lhs,map[vec[i]]-fw); \
} \
fprintf(rulef," then"); \
if (pass!=0) fprintf(rulef," not"); \
fprintf(rulef," %s%d\n",rhs,n);
/**************************************************************************/
/*
 * load_weights: opens netfn for reading (netf) and rulefn for writing (rulef),
 * checks that the version token in the net file is "1.2-100989", then reads
 * insiz, h1siz, h2siz, outsiz and numnodes. Every failure path prints a
 * message and calls exit(0).
 *
 * NOTE(review): the text of this function is damaged. From the first
 * `for (i=0; i maxwets)` line onward, everything that stood between a '<'
 * and the next '>' is missing, so the remainder of load_weights and the start
 * of whatever followed it are gone. The final lines of this span
 * (`print_rule; return 1; ... return 0;`) return a value and therefore cannot
 * belong to this void function; they are presumably the tail of succeed(),
 * which explore() calls -- confirm against the original.
 */
void load_weights()
/* load the dimensions and weights from the neural net file */
{
char vers[11];
int i;
netf = fopen(netfn,"r");
if (netf==NULL) { printf("Error opening input."); exit(0); }
rulef = fopen(rulefn,"w");
if (rulef==NULL) { printf("Error opening output."); exit(0); }
/* skip one token, then read the version string.
 * NOTE(review): %s is unbounded while vers holds 10 characters plus the NUL;
 * a longer token in the file overruns the buffer. */
fscanf(netf,"%*s %s",vers);
if (strcmp(vers,"1.2-100989")) { /* wrong version */
printf("Network file incompatible, version '%s'.\n",vers);
exit(0);
}
/* skip two tokens, the rest of that line, then four more tokens */
fscanf(netf,"%*s %*s %*[^\n] %*s %*s %*s %*s");
/* NOTE(review): the fscanf return values are not checked anywhere here */
fscanf(netf,"%d %d %d %d %d",&insiz,&h1siz,&h2siz,&outsiz,&numnodes);
printf("Input size %d\n",insiz);
printf("Output size %d\n",outsiz);
for (i=0; i<28; i++) /* skip extra stuff */
fscanf(netf,"%*s");
/* NOTE(review): truncated line -- the loop bound/body and the start of the
 * `... > maxwets` size check are missing */
for (i=0; i maxwets) {
printf("Network is too large, must have less than %d weights\n",maxwets);
exit(0);
}
/* NOTE(review): truncated line -- the rest of load_weights and the beginning
 * of the following routine(s) are missing between `i` and ` bar)` */
for (i=0; i bar) {
print_rule;
return 1;
}
return 0;
}
/*
 * explore: recursive search of the statement space, starting from `vec`.
 *
 * vec - the statement to start from: `num` positions in ascending order.
 *       It is only read; the search works on a private heap copy.
 * num - number of entries in vec (callers in this file pass at least 1).
 *
 * The copy has one spare slot so the statement can be lengthened in place
 * before recursing with num+1. Each candidate is handed to succeed()
 * (defined elsewhere in this file):
 *   - on success the last position is advanced and, where two neighbouring
 *     positions differ by exactly 2, the side path opened by bumping the
 *     lower one is explored first;
 *   - on failure the statement is lengthened by one position, provided its
 *     length is below `limit` and the new position is below `wsiz`, and the
 *     search at this level ends.
 *
 * The global `t` is used as scratch space exactly as before (first the byte
 * size of the statement, then the gap between neighbours); recursive calls
 * overwrite it, so it is never read after a recursive call returns.
 *
 * Allocation failure is now reported on stderr and terminates the program
 * instead of passing a null pointer to memcpy.
 */
void explore(int *vec, int num)
{
    int *myvec;
    int i;

    t = num*sizeof(int);              /* compute size of vec in bytes */
    myvec = malloc(t+sizeof(int));    /* one extra slot for lengthening */
    if (myvec == NULL) {
        fprintf(stderr,"Out of memory while exploring rules.\n");
        exit(1);
    }
    memcpy(myvec,vec,t);              /* make a copy of it */

    for (;;) {
        if (!succeed(myvec,num)) {    /* statement failed */
            if (num < limit) {
                myvec[num] = myvec[num-1]+1;
                if (myvec[num] < wsiz)
                    explore(myvec,num+1);
            }
            break;
        }

        /* statement succeeded: look for a side path, scanning right to left */
        for (i=num-2; i>=0; i--) {
            t = myvec[i+1] - myvec[i];
            if (t > 2)                /* gap too wide: no new path */
                break;
            if (t == 2) {             /* gap of 2: explore the new path */
                ++myvec[i];
                explore(myvec,num);
                --myvec[i];
                break;
            }
            /* gap of 1: keep looking further left */
        }

        if (++myvec[num-1] >= wsiz)   /* keep exploring regular path */
            break;
    }

    free(myvec);
}
/**************************************************************************/
/*
 * compare: orders two indices into wet[] by the magnitude of the weights they
 * refer to, largest magnitude first.
 *
 * i, j - pointers to indices into wet[].
 * Returns 1 when |wet[*i]| is smaller than |wet[*j]|, -1 when it is larger,
 * and 0 when the two magnitudes are equal.
 *
 * The previous version returned -1 for equal magnitudes in both argument
 * orders, which is not a consistent ordering; a sort routine such as qsort
 * requires that equal elements compare as 0. (The call site is not visible
 * in this view -- presumably this is the callback used to build the sorted
 * weight order.)
 */
int compare(int *i,int *j)
{
    double a = fabs(wet[*i]);
    double b = fabs(wet[*j]);

    if (a < b)
        return 1;
    if (a > b)
        return -1;
    return 0;
}
/**************************************************************************/
/*
 * extract_rules: generates rules for the nodes of one layer and writes them
 * to rulef.
 *
 * firstnode - offset added to the node counter n when indexing thresh[].
 * numnode   - NOTE(review): not referenced in the visible text; presumably
 *             the bound of the truncated `for (n=0; ...` loop -- confirm.
 * firstwet  - initial value of fw, the offset of the current node's weights
 *             in wet[]; fw is advanced by wsiz at the end of each iteration.
 * numwet    - stored in wsiz.
 *
 * For the second pass (pass = 1) a total is accumulated in totmin, the node's
 * threshold is subtracted, and if the result exceeds bar an unconditional
 * "If TRUE then not ..." rule is written; otherwise explore() is started from
 * the one-term statement vec[0] = 0. The visible remains of the first pass
 * follow the same shape without the "not".
 *
 * NOTE(review): two lines in this function are truncated (text between '<'
 * and '>' is missing): the `for (n=0; ...` header together with the whole
 * first-pass total computation, and the second-pass `for (totmin=0,i=0; ...`
 * header. The function will not compile until they are restored.
 */
void extract_rules(int firstnode,int numnode,int firstwet,int numwet)
/* this takes the weight array of one layer of nodes and generates rules */
{
int i;
wsiz = numwet;
fw = firstwet;
limit = wsiz / 5;
if (limit < 5) limit = 5; /* set limit to wsiz/5 but not less than 5 */
/* NOTE(review): truncated -- loop bound, opening '{', the first-pass setup
 * and the start of the `... > bar` test are missing from the next line */
for (n=0; n bar)
fprintf(rulef,"If TRUE then %s%d\n",rhs,n);
else {
/* start the search from the single-term statement {0} */
vec[0] = 0;
explore(vec,1);
}
pass = 1; /* second pass: prove low outputs */
/* NOTE(review): truncated -- loop bound, increment and the start of the
 * `if (... >0)` test are missing from the next line */
for (totmin=0,i=0; i0) totmin-=wet[i+fw];
totmin -= thresh[n+firstnode];
if (totmin > bar)
fprintf(rulef,"If TRUE then not %s%d\n",rhs,n);
else {
vec[0] = 0;
explore(vec,1);
}
/* advance to the next node's block of wsiz weights */
fw += wsiz;
}
}
/*
 * main: program entry point.
 *
 * Usage: <program> <net file> <rule file> <bar>
 * With exactly three arguments the file names and the bar value are taken
 * from the command line; otherwise they are prompted for on stdin.
 * After load_weights() the rules are extracted layer by layer according to
 * numlayers, and the number of rules generated is printed.
 *
 * Returns 0 on completion and 1 when the arguments or the prompted input
 * cannot be used (file name too long for the 80-byte buffers, bar not a
 * number, or input ended early).
 */
int main(int argc, char *argv[])
{
    if (argc == 4) {
        /* netfn and rulefn are fixed-size; refuse names that would overflow */
        if (strlen(argv[1]) >= sizeof netfn || strlen(argv[2]) >= sizeof rulefn) {
            printf("Filename too long, must be less than %d characters\n",
                   (int)sizeof netfn);
            return 1;
        }
        strcpy(netfn,argv[1]); printf("Input file : %s\n",netfn);
        strcpy(rulefn,argv[2]); printf("Output file : %s\n",rulefn);
        if (sscanf(argv[3],"%f",&bar) != 1) {
            printf("Bar value '%s' is not a number\n",argv[3]);
            return 1;
        }
        printf("Bar is : %1.2f\n",bar);
    }
    else {
        /* the %79s widths match the 80-byte netfn/rulefn buffers */
        printf("Enter input neural net filename : ");
        if (scanf("%79s",netfn) != 1) {
            printf("Error reading input filename.\n");
            return 1;
        }
        printf("Enter output rule filename : ");
        if (scanf("%79s",rulefn) != 1) {
            printf("Error reading output filename.\n");
            return 1;
        }
        printf("Enter bar value : ");
        if (scanf("%f",&bar) != 1) {
            printf("Error reading bar value.\n");
            return 1;
        }
    }

    load_weights();
    rnum = 0;

    switch (numlayers) {
    case 1:
        strcpy(rhs,"OUT");
        strcpy(lhs,"IN");
        extract_rules(insiz,outsiz,0,insiz);
        break;
    case 2:
        /* input -> hidden rules first, then hidden -> output */
        strcpy(rhs,"HID");
        strcpy(lhs,"IN");
        extract_rules(insiz,h1siz,0,insiz);
        strcpy(rhs,"OUT");
        strcpy(lhs,"HID");
        extract_rules(insiz+h1siz,outsiz,insiz*h1siz,h1siz);
        break;
    case 3:
        printf("I can't do three layer nets yet");
        break;
    default:
        /* any other layer count produces no rules, as before */
        break;
    }

    printf("Finished.\n%ld rules generated.",rnum);
    return 0;
}
/* (geocities.com/Paris) -- leftover text from the web page this listing was
 * copied from; not part of the program. */