/** \file training a neural network it uses Fast artificial neural network library (http://fann.sourceforge.net)... load existing network, add new batch, outputs the learn-error, saves the network */ #include #include #include #include #include "local.h" /* Prototypes */ static int help(char *argv[]); static int calcoptions(int argc, char * argv[]); static struct fann * create_network(); static void save_network(struct fann * ann); static void destroy_network(struct fann * ann); static void read_datasets(); int main(int argc, char * argv[]); /** calculates and checks the given options and sets the struct net */ static int calcoptions(int argc, char* argv[]) { int opt_case; optind=1; opt_case=0; /* defaults for network parameters */ net.connection_rate=1.0f; /* full connected MLP network */ net.learning_rate=0.07f; /* 0.07f ;*/ net.num_layers=3; /* one input, one hidden, one output layer */ net.num_input=30; /* count of input neurons */ net.num_neurons_hidden=30; /* how many hidden neurons, 3 is a good choice */ net.num_output=1; net.desired_error=0.0001f; /* training is aborted if learning error is less than... */ net.max_steps=10000; /* how many training steps */ net.steps_between_reports=1; /* how often reports will be generated */ net.net_infile=NULL; net.outfilenames=NULL; net.net_outfile=NULL; net.verbose=0; net.net_writeperiod=net.steps_between_reports; net.infilename=NULL; net.trainfilename=NULL; net.testfilename=NULL; while (TRUE) { /* true */ opt_case=getopt(argc,argv,"c:e:hi:l:m:o:r:v"); if (opt_case == -1 ) {break;} switch (opt_case) { case 'e': net.desired_error=(float) atof(optarg); if(net.desired_error<=0) { fprintf(stderr, "learning error should be greater than zero\n"); (void) help(argv); exit(ERRARG); } break; case 'i': net.net_infile=optarg; break; case 'l': net.learning_rate=(float) atof(optarg); if(net.learning_rate<=0) { fprintf(stderr, "learning rate should be greater than zero\n"); (void) help(argv); exit(ERRARG); } break; case 'm': net.max_steps=(unsigned int) atoi(optarg); if(net.max_steps<1) { fprintf(stderr, "max_steps should be greater than zero\n"); (void) help(argv); exit(ERRARG); } break; case 'o': net.outfilenames=optarg; net.net_outfile=malloc((strlen(net.outfilenames)+5)*sizeof(char)); if (NULL==net.net_outfile) { fprintf(stderr, "could not alloc mem for buf, %s\n", strerror(errno)); exit(ERRMEM); } strcpy(net.net_outfile, net.outfilenames); strcat(net.net_outfile, ".net"); net.net_mqlefile=malloc((strlen(net.outfilenames)+6)*sizeof(char)); if (NULL==net.net_mqlefile) { fprintf(stderr, "could not alloc mem for buf, %s\n", strerror(errno)); exit(ERRMEM); } strcpy(net.net_mqlefile, net.outfilenames); strcat(net.net_mqlefile, ".mqle"); net.net_mqgefile=malloc((strlen(net.outfilenames)+6)*sizeof(char)); if (NULL==net.net_mqgefile) { fprintf(stderr, "could not alloc mem for buf, %s\n", strerror(errno)); exit(ERRMEM); } strcpy(net.net_mqgefile, net.outfilenames); strcat(net.net_mqgefile, ".mqge"); net.net_cfgfile=malloc((strlen(net.outfilenames)+5)*sizeof(char)); if (NULL==net.net_cfgfile) { fprintf(stderr, "could not alloc mem for buf, %s\n", strerror(errno)); exit(ERRMEM); } strcpy(net.net_cfgfile, net.outfilenames); strcat(net.net_cfgfile, ".cfg"); break; case 'r': net.steps_between_reports=(unsigned int) atoi(optarg); if(net.steps_between_reports<1) { fprintf(stderr, "steps_between_reports should be greater than zero\n"); (void) help(argv); exit(ERRARG); } break; case 'v': net.verbose=1; break; case '?': fprintf(stderr, "there was an unknown parameter: %i \n", optopt); (void) help(argv); exit(0); case ':': fprintf(stderr, "there was a required parameter missed: %s \n", optarg); (void) help(argv); exit(0); case 'h': (void) help(argv); exit(0); default: /*printf("arg was a option: %s\n",optarg);*/ break; } } if ((argc-optind)!=2) { /* <1 means too few args */ fprintf(stderr, "wrong count of arguments, / missed\n"); (void) help(argv); exit(0); } else { net.trainfilename = argv[optind]; net.testfilename = argv[optind+1]; } return (0); } /** prints the help-page if there were errors on the command-line or if explicitely ordered */ static int help(char * argv[]) { fprintf(stderr,"train net -o \n"); fprintf(stderr," reads fann DAT-files 'train' and 'test', and trains a\n"); fprintf(stderr," neural network. The network will be saved,\n"); fprintf(stderr," if max_steps will be reached or actual\n"); fprintf(stderr," error below desired_error.\n"); fprintf(stderr,"call\n %s [options] \n", argv[0]); fprintf(stderr,"or\n %s -h\n",argv[0]); fprintf(stderr,"options are:\n"); fprintf(stderr, "\t-a\tadaptr learnrate, means per epoch the learningrate decrease (default:%u)\n", net.adapt_learnrate); fprintf(stderr, "\t-e VAL\tlearning error over an epoch (float, default: %03f)\n", net.desired_error); fprintf(stderr, "\t-h \thelp\n"); fprintf(stderr, "\t-i VAL\tnet file to load (string, default: %s)\n", net.net_infile); fprintf(stderr, "\t-l VAL\tlearning_rate (float, default:%0.10f)\n", net.learning_rate); fprintf(stderr, "\t-m VAL\tmax_steps (unsigned int, default: %u)\n", net.max_steps); fprintf(stderr, "\t-o VAL\toutfiles (outfile names without extensions) to store (string, default:%s)\n",net.outfilenames); fprintf(stderr, "\t-r VAL\tsteps_between_reports (unsigned int, default: %u)\n", net.steps_between_reports); fprintf(stderr, "\t-v\t verbose (default: %u)\n", net.verbose); fprintf(stderr, "\n"); return(0); } /** small function to encapsulate the creation/loading of ANN */ static struct fann * create_network() { struct fann * ann; /* create or load neural network */ if (NULL != net.net_infile) { /* load network-data from file */ /* TODO */ ann = fann_create_from_file(net.net_infile); if (ann == NULL) { fprintf(stderr, "could not create network from file %s: %s\n", net.net_infile, strerror(errno)); exit(ERRFANN); } } else { /* build network from data in net-structure */ ann = fann_create(net.connection_rate, net.learning_rate, net.num_layers,net.num_input, net.num_neurons_hidden, net.num_output); if (ann == NULL) { fprintf(stderr, "could not create network: %s\n", strerror(errno)); exit(ERRFANN); } fann_randomize_weights(ann, -0.9f, 0.9f); /* only with fann > 1.2, speeds up learning: fann_set_momentum(ann, 0.9); */ fann_set_activation_function_output(ann, FANN_SIGMOID_SYMMETRIC); fann_set_activation_function_hidden(ann, FANN_SIGMOID_SYMMETRIC); } return(ann); } /** small function to encapsulate the peridocal backup mechanism to store the actual trained ANN */ static void save_network(struct fann * ann) { assert(ann !=NULL); /* save network */ if (net.net_outfile!=NULL) { /* only if filename specified... */ fann_save(ann,net.net_outfile); } } /** small function to free the ANN structure */ static void destroy_network(struct fann * ann) { assert(ann!=NULL); fann_destroy(ann); } /** readin data into training and test set */ static void read_datasets () { net.train_data = fann_read_train_from_file(net.trainfilename); net.test_data = fann_read_train_from_file(net.testfilename); fann_shuffle_train_data(net.train_data); } /** main logic to prepare ANN and feed it with the right data to train an ANN */ int main(int argc, char * argv[]) { FILE * fh_mqle; FILE * fh_mqge; struct fann * ann; float mqle; float mqge; unsigned int i,j; struct timeval lasttimestamp; struct timeval timestamp; gettimeofday(×tamp,NULL); lasttimestamp.tv_sec=timestamp.tv_sec; lasttimestamp.tv_usec=timestamp.tv_usec; #ifndef FLOATFANN /* assertion that we use float as fann_type */ #error "to avoid type conflicts, the float-variant of libfann must be used" #endif (void) calcoptions(argc, argv); if (net.outfilenames!=NULL) { fh_mqle=fopen(net.net_mqlefile, "w"); if (fh_mqle==NULL) { fprintf(stderr, "File %s could not be open, %s (line %i)\n", net.net_mqlefile, strerror(errno), __LINE__); exit(ERRSTREAM); } fprintf(fh_mqle, "# sample, fh_mqle over %u steps, time in s\n", net.steps_between_reports); fh_mqge=fopen(net.net_mqgefile, "w"); if (fh_mqge==NULL) { fprintf(stderr, "File %s could not be open, %s (line %i)\n", net.net_mqgefile, strerror(errno), __LINE__); exit(ERRSTREAM); } fprintf(fh_mqge, "# sample, fh_mqge over %i steps\n", net.max_steps); } else { fh_mqge=stdout; fh_mqle=stderr; } read_datasets(); ann=create_network(); /* train network */ fprintf(stderr, "total neurons: %u, total connections: %u\n", fann_get_total_neurons(ann), fann_get_total_connections(ann)); fprintf(stderr, "start with training\n"); (void) fflush(stderr); mqle=10000.0f; mqge=10000.0f; /* for every epoch */ for (i=0; i< net.max_steps && mqge > net.desired_error; i++) { int rindex=i % net.train_data->num_data; fann_reset_MSE(ann); /* reset mean learning error */ fann_train(ann, net.train_data->input[rindex],net.train_data->output[rindex]); mqle=fann_get_MSE(ann); if (0 == i%net.steps_between_reports) { /* calc MQGE */ fann_reset_MSE(ann); /* reset mean learning error */ for (j=0; j num_data; j++) { fann_test(ann, net.test_data->input[j], net.test_data->output[j]); } mqge=fann_get_MSE(ann); if (net.outfilenames!=NULL) { gettimeofday(×tamp,NULL); fprintf(fh_mqle, "%u %0.5f %0.5f\n", i, mqle, (float) (timestamp.tv_sec-lasttimestamp.tv_sec)+((float) (timestamp.tv_usec-lasttimestamp.tv_usec)/1000)); fprintf(fh_mqge, "%u %0.5f\n", i, mqge); lasttimestamp.tv_sec=timestamp.tv_sec; lasttimestamp.tv_usec=timestamp.tv_usec; /* to trace learning with gnuplot, per epoch */ (void) fflush(fh_mqle); (void) fflush(fh_mqge); } printf ("step:%i MQLE:%0.5f MQGE:%0.5f\n", i, mqle, mqge); } if ( /* save network periodically, only if filename specified... */ (net.net_outfile!=NULL) && (0 == i%net.net_writeperiod) ) { save_network(ann); } (void) fflush(stderr); } /* end for loop, end of training */ save_network(ann); destroy_network(ann); if (net.outfilenames!=NULL) { fclose(fh_mqle); fclose(fh_mqge); } return(0); }