#include <nagdmc.h>

/*
  handle_info(): print a message describing the status code `info` that was
  returned by the nagdmc routine named `func`, and return non-zero when the
  caller should abandon the run.
*/
int
handle_info(const char *func, int info);

/*
  step_through(): the routine shown under 'Explanatory Code' in
  nagdmc_reg_tree.pdf; prints every node of the tree lattice whose root
  address is passed in ipnode.
*/
void
step_through(long *bcat, long ipnode);

/*
  main() runs the regression-tree example end to end:
    1. reads dblk records of nvar values each from the text file
       "bostonh.dat";
    2. fits a regression tree with nagdmc_reg_tree and prints every node
       via step_through();
    3. saves the tree to "bostonh_reg_tree_save", frees it and loads it
       back with nagdmc_load_reg_tree;
    4. predicts the response for nrec records with the reloaded tree and
       prints observed value, prediction and accuracy.
  Returns 0 on success and 2 on any failure (allocation, file I/O, or a
  non-zero result from handle_info()).
*/
int 
main(void) {
    const char    file[] = {"bostonh.dat"};
    const char    sfile[] = {"bostonh_reg_tree_save"};
    long          rec1 = 0;
    long          nvar = 14;
    long          nrec = 506;
    long          dblk = 506;
    double       *data = 0;
    long          nxvar = 0;
    long         *xvar = 0;
    long          yvar = 13;
    long          iwts = -1;
    long          ncat[] = {0,0,0,2,0,0,0,0,0,0,0,0,0,0};
    long          bcat[] = {0,0,0,1,0,0,0,0,0,0,0,0,0,0};
    long          mnc = 7;
    long          mns = 20;
    double        alpha = 0.2;
    long          iproot = 0;
    int           info = 0;
    int           optrand = 0;
    long          iseed = -1;
    double       *res = 0;
    double       *acc = 0;
    FILE         *fp = 0;
    long          i, j;
    

    /*
      Allocate memory for data and read data values.
    */
    if (!(data = malloc(dblk*nvar * sizeof *data))) {
        printf (" Memory allocation failure.\n\n");
        return 2;
    }

    if ((fp = fopen(file,"r")) == 0) {
        printf("\n Data file named %s not found. \n\n",file);
        free(data);
        return 2;
    }
    
    for (i=0; i<dblk; ++i) {
        for (j=0; j<nvar; ++j) {
            /* A short or malformed file would otherwise leave entries of
               data uninitialised and feed garbage to the tree routines. */
            if (fscanf(fp,"%lf ",&data[i*nvar+j]) != 1) {
                printf("\n Error reading data file %s "
                       "(record %li, variable %li).\n\n",file,i+1,j+1);
                fclose(fp);
                free(data);
                return 2;
            }
        }

        /* NOTE(review): shifts every response value by +1.0; the reason is
           not evident from this file - confirm against the example docs. */
        data[i*nvar+yvar] += 1.0;
    }
    
    fclose(fp);

    /*
      Compute tree lattice.
    */
    nagdmc_reg_tree(rec1,nvar,nrec,dblk,data,nxvar,xvar,yvar,iwts,ncat,bcat,
                    mns,mnc,alpha,&iproot,&info);

    if (handle_info("nagdmc_reg_tree",info)) {
        free(data);
        return 2;
    }

    /*
      Explanatory code from doc.
    */
    step_through(bcat,iproot);

    /*
      Save lattice in file.
    */
    nagdmc_save_reg_tree(iproot,sfile,&info);
    
    if (handle_info("nagdmc_save_reg_tree",info)) {
        free(data);
        if (iproot)
            nagdmc_free_reg_tree(iproot);
        return 2;
    }

    /*
      Free memory containing lattice.
    */
    nagdmc_free_reg_tree(iproot); iproot = 0;

    /*
      Load lattice into memory.
    */
    nagdmc_load_reg_tree(sfile,&iproot,&info);

    if (handle_info ("nagdmc_load_reg_tree",info)) {
        free(data);
        return 2;
    }

    /*
      Compute predictions using the tree lattice.
    */
    res = malloc(nrec * sizeof *res);
    acc = malloc(nrec * sizeof *acc);
    if (!res || !acc) {
        printf(" Memory allocation failure.\n\n");
        free(data);
        /* Either buffer may have been allocated; free(0) is a no-op. */
        free(res);
        free(acc);
        if (iproot)
            nagdmc_free_reg_tree(iproot);
        return 2;
    }
    
    nagdmc_predict_reg_tree(rec1,nvar,nrec,dblk,data,bcat,iproot,optrand,iseed,
                            res,acc,&info);

    if (handle_info("nagdmc_predict_reg_tree",info)) {
        free(data);
        if (iproot)
            nagdmc_free_reg_tree(iproot);
        free(res);
        free(acc);
        return 2;
    }

    /*
      Print predictions.
    */
    printf("\n Observed\tPredicted\tAccuracy\n\n");
    for (i=0; i<nrec; ++i)
        printf(" %-8g\t%-9g\t%-8.4f\n",data[(i+rec1)*nvar+yvar],res[i],acc[i]);

    /*
      Return allocated memory to operating system.
    */
    free(data);
    if (iproot != 0)
        nagdmc_free_reg_tree(iproot);
    free(res);
    free(acc);

    return 0;
}    

/*
  step_through() is the function described in 'Explanatory Code' in
  nagdmc_reg_tree.pdf. It walks the tree lattice rooted at ipnode in
  pre-order (node, then left subtree, then right subtree) and prints the
  fields of every node.

  bcat   - per-variable offset added to the category index when a
           categorical split is printed (presumably the lowest category
           label of each variable); may be 0, in which case no offset is
           applied.
  ipnode - address of an RTNode passed as a long, as in the nagdmc API;
           0 means an empty subtree.
           NOTE(review): a long cannot hold a pointer on LLP64 platforms
           (e.g. 64-bit Windows); this is dictated by the API signature.
*/
void
step_through(long bcat[], long ipnode) {
    long          i, base;
    RTNode       *lnode = (RTNode *)ipnode;

    if (lnode == 0)
        return;

    /* %p requires a void * argument (C11 7.21.6.1). */
    printf("\n Node   %8p"
           "\n Parent %8p"
           "\n type:  %8i"
           "\n svar:  %8li"
           "\n sval:  %8.4f"
           "\n rss:   %8.4f"
           "\n ybar:  %8.4f"
           "\n yvar:  %8.4f"
           "\n ndata: %8li",
           (void *)lnode,(void *)lnode->parent,lnode->type,lnode->svar,
           lnode->sval,lnode->rss,lnode->ybar,lnode->yvar,lnode->ndata);

    if (lnode->ncats > 0) {
        /* Only categorical splits print category labels, so bcat is
           indexed by svar only here rather than for every node. */
        base = (bcat != 0) ? bcat[lnode->svar] : 0;

        printf("\n lr:          ");
        for (i=0; i<lnode->ncats; ++i) {
            /* Categories whose lr entry is 'a' are not printed. */
            if (lnode->lr[i] != 'a')
                printf (" Cat. %li goes %c;",base+i,lnode->lr[i]);
        }
        printf("\n");
    }

    printf("\n");

    step_through(bcat,(long)(lnode->lchild));
    step_through(bcat,(long)(lnode->rchild));
}

/*
  handle_info() reports the status code returned by a nagdmc_* routine.

  func - name of the routine, used in the printed message.
  info - status code returned by that routine.

  Returns:
    2 when info is -999 (after printing a licence message);
    1 when info is positive (after printing an error message);
    0 otherwise. A negative info other than -999 is printed as an
    informational message, but the result is still 0.
*/
int
handle_info(const char func[], int info) {
    if (info == 0)
        return 0;

    if (info == -999) {
        printf(" Invalid licence, please contact NAG.\n\n");
        return 2;
    }

    if (info < 0) {
        printf (" Information code %i from %s.\n\n",info,func);
        return 0;
    }

    printf(" Error code %i from %s.\n\n",info,func);
    return 1;
}


/* syntax highlighted by Code2HTML, v. 0.8.11 */