#include <stdio.h>
#include <stdlib.h>

#include <nagdmc.h>

/*
  handle_info() reports the info code returned by the NAG DMC routine named
  func.

  Returns 2 (after printing a licence message) when info == -999, 1 (after
  printing an error message) when info > 0, and 0 otherwise. Other negative
  codes are printed as information codes but still return 0, so the caller
  carries on. Nothing is printed when info == 0.
*/
int
handle_info(const char func[], int info);

/*
  step_through() is the function described in 'Explanatory Code' in
  nagdmc_gini_tree.pdf.

  It prints the fields of the tree node whose address is held in ipnode, then
  recurses into the node's left and right children, so nodes are visited in
  pre-order. A null (0) node prints nothing. If bcat is non-null,
  bcat[svar] is added to the category numbers printed for a node's split
  variable.
*/
void 
step_through(long bcat[], long ipnode);

/*
  read_data_file() reads nrow*ncol whitespace-separated numbers from the text
  file named fname into buf, row by row (ncol values per row).
  Returns 0 on success, 1 if the file could not be opened, and 2 if the file
  did not contain nrow*ncol readable numbers.
*/
static int
read_data_file(const char fname[], long nrow, long ncol, double buf[]) {
    FILE         *fp;
    long          i;
    long          nval = nrow*ncol;
    int           status = 0;

    if ((fp = fopen(fname,"r")) == 0)
        return 1;

    for (i=0; i<nval; ++i) {
        /*
          Stop at the first missing or malformed value; otherwise the rest of
          buf would be left uninitialized and later read by the NAG routines.
        */
        if (fscanf(fp,"%lf",&buf[i]) != 1) {
            status = 2;
            break;
        }
    }

    fclose(fp);
    return status;
}

/*
  main() builds a tree from iris.dat with nagdmc_gini_tree(), prints its
  nodes, saves it to iris_gini_save, frees it, reloads it from that file and
  prints the predictions the reloaded tree makes for each record.

  Returns 0 on success and 2 on any failure. Every exit goes through the
  cleanup label, so memory and the tree are released on all paths.
*/
int 
main(void) {
    const char    file[] = {"iris.dat"};
    const char    sfile[] = {"iris_gini_save"};
    long          rec1 = 0;
    long          nvar = 5;
    long          nrec = 150;
    long          dblk = 150;
    double       *data = 0;
    long          nxvar = 4;
    long          xvar[] = {0,1,2,3};
    long          yvar = 4;
    long          ncat[] = {0,0,0,0,3};
    long         *bcat = 0;
    double       *prior = 0;
    long          mnc = 7;
    long          mns = 20;
    double        alpha = 0.2;
    long          iproot = 0;   /* 0 whenever no tree is held */
    int           info = 0;
    int           status = 2;   /* becomes 0 only after full success */
    long          i;
    /*
      Parameters for nagdmc_predict_gini_tree().
    */
    int           optrand = 0;
    long          iseed = -1;
    long         *score = 0;
    double       *rm = 0;
    long          ncaty;

    /*
      Initial values.
    */
    ncaty = ncat[yvar];

    /*
      Allocate memory for data and load data values.
    */
    if (!(data = malloc(dblk*nvar * sizeof *data))) {
        printf(" Memory allocation failure.\n\n");
        goto cleanup;
    }

    switch (read_data_file(file,dblk,nvar,data)) {
    case 0:
        break;
    case 1:
        printf(" Data file named %s was not found.\n\n",file);
        goto cleanup;
    default:
        printf(" Data file named %s does not contain %li values.\n\n",
               file,dblk*nvar);
        goto cleanup;
    }

    /*
      Build tree lattice.
    */
    nagdmc_gini_tree(rec1,nvar,nrec,dblk,data,nxvar,xvar,yvar,ncat,bcat,prior,
                     mns,mnc,alpha,&iproot,&info);

    if (handle_info("nagdmc_gini_tree",info)) {
        /*
          Treat a failed build as yielding no tree.
          NOTE(review): confirm nagdmc_gini_tree() allocates nothing when it
          reports an error.
        */
        iproot = 0;
        goto cleanup;
    }

    /*
      Print information on nodes in lattice.
    */
    step_through(bcat,iproot);

    /*
      Save lattice in file.
    */
    nagdmc_save_gini_tree(ncaty,iproot,sfile,&info);

    if (handle_info("nagdmc_save_gini_tree",info))
        goto cleanup;

    /*
      Free memory of lattice.
    */
    nagdmc_free_gini_tree(iproot); iproot = 0;

    /*
      Load lattice into memory.
    */
    nagdmc_load_gini_tree(ncaty,sfile,&iproot,&info);

    if (handle_info("nagdmc_load_gini_tree",info)) {
        /*
          NOTE(review): confirm nagdmc_load_gini_tree() allocates nothing
          when it reports an error.
        */
        iproot = 0;
        goto cleanup;
    }

    /*
      Compute predictions.
    */
    if (!(score = malloc(nrec * sizeof *score)) ||
        !(rm = malloc(nrec * sizeof *rm))) {
        printf(" Memory allocation failure.\n\n");
        goto cleanup;
    }

    nagdmc_predict_gini_tree(rec1,nvar,nrec,dblk,data,yvar,bcat,iproot,optrand,
                             iseed,score,rm,&info);

    if (handle_info("nagdmc_predict_gini_tree",info))
        goto cleanup;

    /*
      Print predictions.
    */
    printf ("\n Observed\tPredicted\tRate\n\n");
    for (i=0; i<nrec; ++i)
        printf (" %-8li\t%-9li\t%-4.2g\n",
                (long)data[(i+rec1)*nvar+yvar],score[i],rm[i]);

    status = 0;

cleanup:
    if (iproot != 0)
        nagdmc_free_gini_tree(iproot);
    free(score);   /* free(NULL) is a no-op */
    free(rm);
    free(data);

    return status;
}

/*
  step_through() prints the fields of the tree node whose address is held in
  ipnode, then recurses into its left and right children (pre-order).
  A null (0) node prints nothing.

  For a node with ncats > 0, each lr[] entry other than 'a' is listed as
  "Cat. <k> goes <c>". k is the entry's index plus bcat[svar] when bcat is
  non-null, or the plain index otherwise.

  NOTE(review): node addresses travel as long (the library's convention). This
  truncates on platforms where long is narrower than a pointer (e.g. 64-bit
  Windows).
*/
void
step_through(long bcat[], long ipnode) {
    CTNode       *lnode = (CTNode *)ipnode;
    long          i, offset;

    if (lnode == 0)
        return;

    /* %p requires a void * argument, so convert the node pointers. */
    printf("\n Node   %8p"
           "\n Parent %8p"
           "\n type:  %8i"
           "\n svar:  %8li"
           "\n sval:  %8.4f"
           "\n giv:   %8.4f"
           "\n imp:   %8.4f"
           "\n yval:  %8li"
           "\n ndata: %8li",
           (void *)lnode,(void *)lnode->parent,lnode->type,lnode->svar,
           lnode->sval,lnode->giv,lnode->improve,lnode->yval,lnode->ndata);

    if (lnode->ncats > 0) {
        /*
          Only index bcat by svar for nodes that actually list categories,
          so svar is never used as an index on nodes where it is not needed.
        */
        offset = (bcat != 0) ? bcat[lnode->svar] : 0;

        printf("\n lr:          ");
        for (i=0; i<lnode->ncats; ++i) {
            if (lnode->lr[i] != 'a')
                printf(" Cat. %li goes %c;",offset+i,lnode->lr[i]);
        }
        printf("\n");
    }

    printf("\n");

    step_through(bcat,(long)(lnode->lchild));
    step_through(bcat,(long)(lnode->rchild));
}

/*
  handle_info() maps the info code returned by the NAG DMC routine named func
  to a message and an action code for the caller:
    -999          -> licence message, returns 2
    positive      -> error message,   returns 1
    other negative -> information message, returns 0
    zero          -> silent,          returns 0
*/
int
handle_info(const char func[], int info) {
    /* Licence failure has its own dedicated code. */
    if (info == -999) {
        printf(" Invalid licence, please contact NAG.\n\n");
        return 2;
    }

    /* Positive codes are errors; the caller is expected to stop. */
    if (info > 0) {
        printf(" Error code %i from %s.\n\n",info,func);
        return 1;
    }

    /* At this point info <= 0; any non-zero value is only informational. */
    if (info != 0) {
        printf (" Information code %i from %s.\n\n",info,func);
    }

    return 0;
}


/* syntax highlighted by Code2HTML, v. 0.8.11 */