#include <nagdmc.h>

/*
  handle_info() reports the status code `info` returned by the NAG DMC
  routine named `func`, printing a message for any non-zero code.
  Returns 2 for the invalid-licence code (-999), 1 for any positive (error)
  code, and 0 otherwise: a zero code, or a negative information code that is
  printed but not treated as fatal.
*/
int
handle_info(const char func[], int info);

/*
  step_through() recursively prints the nodes of the entropy tree whose root
  node address is held in `node` (0 prints nothing). It is the function
  described in the function document nagdmc_entropy_tree.pdf.
*/
void
step_through(long node);

/*
  Example driver for the NAG DMC entropy classification tree routines.

  Reads dblk records of nvar values from the file wisc.dat, builds a
  classification tree with nagdmc_entropy_tree(), prunes it, prints its
  nodes, saves it to the file entropy_tree_save, frees it, and reloads it
  from that file. It then uses the reloaded tree to classify the first nrec
  records, printing each classification (with its probability) and the
  number of classifications that equal the value of variable yvar.

  Returns 0 on success and 2 on any failure. Every failure path releases
  the memory acquired so far. The tree itself is only freed once a routine
  that creates it (build or load) has reported success.
*/
int
main(void) {
    const char    file[] = {"wisc.dat"};
    const char    savefile[] = {"entropy_tree_save"};
    long          rec1 = 0;
    long          nvar = 10;
    long          nrec = 683;
    long          dblk = 683;
    double       *data = 0;
    long          nxvar = 0;
    long         *xvar = 0;
    long          yvar = 9;
    long          ncat[] = {0,0,0,0,0,0,0,0,0,2};
    long          cat[] = {1,2};
    long          mnc = 10;
    long          iproot = 0;
    long         *res = 0;
    double       *prob = 0;
    int           info = 0;
    int           optrand = 0;
    long          iseed = -1;
    FILE         *fp = 0;
    long          i, j, correct = 0;
    int           status = 2;       /* exit status; set to 0 only on success */

    /*
      Allocate memory for data and load data values.
    */
    if (!(data = (double *)malloc(dblk*nvar * sizeof(double)))) {
        printf(" Memory allocation failure.\n\n");
        goto free_arrays;
    }

    if ((fp = fopen(file,"r")) == 0) {
        printf(" Data file named %s was not found.\n\n",file);
        goto free_arrays;
    }

    /*
      Every value must be read: a short or malformed file would otherwise
      leave part of data uninitialised before it is used to build the tree.
    */
    for (i=0; i<dblk; ++i) {
        for (j=0; j<nvar; ++j) {
            if (fscanf(fp,"%lf ",&data[i*nvar+j]) != 1) {
                printf(" Could not read value %li of record %li from %s.\n\n",
                       j,i,file);
                fclose(fp);
                goto free_arrays;
            }
        }
    }

    fclose(fp);

    /*
      Compute classification tree by using entropy criterion.
      If the build fails the tree is not freed, because the state of iproot
      after a failed build is not visible from this file.
    */
    nagdmc_entropy_tree(rec1,nvar,nrec,dblk,data,nxvar,xvar,yvar,ncat,cat,mnc,
                        &iproot,&info);
    if (handle_info("nagdmc_entropy_tree",info))
        goto free_arrays;

    /*
      Pessimistic error pruning of tree lattice.
      NOTE(review): this call is not passed &info, so the check below still
      sees the status left by nagdmc_entropy_tree(). Confirm the prototype in
      nagdmc.h, and pass &info if the routine reports a status.
    */
    nagdmc_prune_entropy_tree(iproot);
    if (handle_info("nagdmc_prune_entropy_tree",info))
        goto free_tree;

    /*
      Explanatory code from function doc.
    */
    step_through(iproot);

    /*
      The example of saving and re-loading the tree lattice now follows.
    */

    /*
      Save tree.
    */
    nagdmc_save_entropy_tree(iproot,savefile,&info);
    if (handle_info("nagdmc_save_entropy_tree",info))
        goto free_tree;

    /*
      Free memory containing the tree.
    */
    nagdmc_free_entropy_tree(iproot);

    /*
      Reset the root-node handle so that the cleanup code below can never
      free the released tree a second time.
    */
    iproot = 0;

    /*
      Load the tree into memory from the saved file.
    */
    nagdmc_load_entropy_tree(savefile,&iproot,&info);
    if (handle_info("nagdmc_load_entropy_tree",info))
        goto free_arrays;

    /*
      Calculate predictions for training data using the tree.
      First allocate memory for the return arrays.
    */
    if (!(res = (long *)malloc(nrec * sizeof(long)))) {
        printf(" Memory allocation failure.\n\n");
        goto free_tree;
    }

    if (!(prob = (double *)malloc(nrec * sizeof(double)))) {
        printf(" Memory allocation failure.\n\n");
        goto free_tree;
    }

    nagdmc_predict_entropy_tree(rec1,nvar,nrec,dblk,data,iproot,optrand,iseed,
                                res,prob,&info);
    if (handle_info("nagdmc_predict_entropy_tree",info))
        goto free_tree;

    printf("\n\n Decision tree classifications.\n"
            "------------------------------");
    printf("\n\n Datum\tClassification\n");

    for (i=0; i<nrec; ++i) {
        printf(" %3li\t%8li\t(%4.3f)\n",i,res[i],prob[i]);

        if (res[i] == (long)data[i*nvar+yvar])
            correct++;
    }

    printf("\n Number of correct classifications: %li.\n\n",correct);

    status = 0;

    /*
      Cleanup: reached on success and on every failure path. free(0) is a
      no-op, so the arrays can be released unconditionally.
    */
free_tree:
    if (iproot)
        nagdmc_free_entropy_tree(iproot);
free_arrays:
    free(data);
    free(res);
    free(prob);

    return status;
}

/*
  Recursively print every node of the entropy tree whose root node address
  is held in `node`; a value of 0 prints nothing.

  For each node this prints its address, its parent, the number of data at
  the node, its modal class and its class distribution. A node with
  index > 0 additionally reports the partitioning index, plus the test value
  stored in its first child when the split is not discrete. Its children
  are then visited in order. A node with index <= 0 is reported as a leaf.
*/
void
step_through(long node) {
    long i;
    ENode *lnode;

    /* The library hands tree nodes around as addresses held in a long. */
    lnode = (ENode *)node;

    if (lnode == 0)
        return;

    /* %p requires a void * argument, so the node pointers are converted. */
    printf("\n Node %8p"
           "\n Parent %8p"
           "\n No. data at node %8li"
           "\n Modal class %li",
           (void *)lnode,(void *)lnode->parent,lnode->ndata,
           lnode->modal_class);

    printf("\n Class distribution at node:");

    for (i=0; i<lnode->nclasses; ++i)
        printf(" %li",lnode->ninclasses[i]);

    if (lnode->index > 0) {
        printf("\n Partition on index %li",lnode->index);

        if (lnode->discrete == 0)
            printf("; continuous test value: %8.4f",lnode->children[0]->value);

        printf("\n");

        for (i=0; i<lnode->nchildren; ++i)
            step_through((long)(lnode->children[i]));
    }
    else
        printf("\n Leaf node\n");
}

/*
  Report the status code `info` returned by the NAG DMC routine `func`.

  Return value:
    2  licence failure (info == -999)
    1  error (info > 0); the code is printed
    0  success (info == 0), or an informational negative code, which is
       printed but not treated as fatal
*/
int
handle_info(const char func[], int info) {
    /* The licence check uses a dedicated sentinel value. */
    if (info == -999) {
        printf(" Invalid licence, please contact NAG.\n\n");
        return 2;
    }

    /* Nothing to report. */
    if (info == 0)
        return 0;

    /* Negative codes are warnings: report them and carry on. */
    if (info < 0) {
        printf (" Information code %i from %s.\n\n",info,func);
        return 0;
    }

    printf(" Error code %i from %s.\n\n",info,func);
    return 1;
}


/* syntax highlighted by Code2HTML, v. 0.8.11 */