package jcm.mod.math;

import jcm.core.cur.curve;
import static jcm.core.report.*;


/**
 * Stepwise simplex minimiser (reflect / expand / contract / total-contraction,
 * in the style of the Nelder-Mead downhill simplex method).
 *
 * <p>The optimiser does not evaluate the objective function itself. The caller
 * drives it in a loop:
 * <ol>
 *   <li>call {@link #nextpoint()} to obtain the next trial point,</li>
 *   <li>apply the first {@code nd} entries of that array to the model,</li>
 *   <li>call {@link #nextpoint()} again; it first reads the resulting objective
 *       value through {@link optimcaller#getobjfunc()} and then advances.</li>
 * </ol>
 * The loop ends when {@link #goodenough} becomes {@code true}.
 *
 * <p>Every point array has {@code nd+1} slots: {@code nd} coordinates followed
 * by the objective value for that point.
 */
public class optimiser {

    /** Callback through which the optimiser reads the objective value of the last trial point. */
    //inner interface - beware may not work with earliest versions of java
    public interface optimcaller {
        //		void setparams(float[] vals);
        /** @return objective value for the point most recently returned by {@code nextpoint()} */
        float getobjfunc();
    }

    // nd  = number of dimensions (initguess.length), np = nd+1 = number of simplex points
    // cp  = index into sx: next point to hand out during fill/checkopt, index of the highest point after findhighlow()
    // ni  = number of nextpoint() passes since the last reset()
    // ni2 = count of failed variance checks (only written in this class, never read)
    int nd, np, cp, ni, ni2;
    float[][] sx; // [point][dimension]; sx[p][nd] is the objective value of point p
    // cen = centre of the simplex, hip/lop = references to the highest/lowest rows of sx,
    // rfp = copy of the reflected point, newp = point most recently handed to the caller
    float[] cen, hip, lop, rfp, newp, initguess, perturb, max, min;
    // of = last objective value, tol = convergence tolerance on nsd,
    // tol2 = fraction of perturb used for the optimality probe, nsd = sd/mean of the objective over the simplex
    float of, tol, tol2, nsd;

    optimcaller oc;
    /** Text of the most recent progress line (also sent to {@code log}). */
    public String report;
    /** Current state of the state machine; one of the step constants below. */
    public int step=0;
    static int reset=0, fill=1, reflect=2, compare=3, checkcontract=4, checkexpand=5, checkopt=6;
    static String[] stepname={	"reset", "fill", "reflect", "compare", "chkcon", "chkexp", "chkopt"};
    /** {@code goodenough}: search finished. {@code oktocheckvar}: run the variance check on the next reflect step. */
    public boolean goodenough=false, oktocheckvar=false;
    /** When true a progress line is logged on every pass, otherwise only every 50th. */
    static public boolean showreport=true;

    // true when the last evaluated point was outside [min, max] in at least one dimension
    private boolean hitlimits=false;

    /**
     * Creates an optimiser for a problem with {@code initguess.length} dimensions.
     *
     * @param oc        callback supplying the objective value
     * @param initguess starting point (first simplex vertex)
     * @param perturb   per-dimension offset used to build the other vertices
     * @param min       per-dimension lower bound; points below it are scored {@code Float.MAX_VALUE}
     * @param max       per-dimension upper bound; points above it are scored {@code Float.MAX_VALUE}
     */
    public optimiser(optimcaller oc, float[] initguess, float[] perturb, float[] min, float[] max  ) {
        this.oc=oc; this.initguess=initguess; this.perturb=perturb;
        this.max=max; this.min=min;
        nd=initguess.length; np=nd+1;
        step=reset;
        tol=0.0001f; tol2=0.05f; //tolerance: max deviation in objective function, frac of initial perturbation to check optimality
        sx= new float[np][np];
        rfp=new  float[np]; newp=new float[np];
        cen=new float[nd];
    }

    /**
     * Restarts the search. The current centre array {@code cen} becomes the new
     * initial guess (by reference), counters are zeroed and {@code goodenough} is cleared.
     */
    public void reset() {
        ni=0; ni2=0; step=reset; initguess=cen; goodenough=false;
    }

    /**
     * Advances the state machine by one objective evaluation.
     *
     * <p>Unless the optimiser has just been reset, the objective value of the
     * previously returned point is read first (see {@link #getobjfunc()}).
     *
     * @return the next point to evaluate (length {@code nd+1}, coordinates in the
     *         first {@code nd} slots); the lowest point once {@code goodenough}
     *         is set; {@code null} only if {@code step} holds an unknown value
     */
    public float[] nextpoint() {
        if (step==reset) {
            cp=0; init(initguess, 1f); step=fill; nsd=curve.dud;
            hitlimits=false; // nothing has been evaluated yet in this run
        } else {
            getobjfunc();
        }
        if (showreport || (ni%50)==0 ) report();
        ni++;

        // hand out the simplex vertices one by one so each gets an objective value
        if (step==fill) {
            if (cp<np) {	newp=sx[cp]; cp++; return newp; } else step=reflect;
        }
        // optimality probe: a small simplex was built around the best point (sx[0]);
        // if no probe vertex beats it the search is finished
        if (step==checkopt) {
            if (cp<np) {	newp=sx[cp]; cp++; return newp; }
            findhighlow();
            if (lop==sx[0]) {
                goodenough=true;
                // the marker is appended inside the report builder, which rebuilds `report` from scratch
                logreport(" end-opt ");
                return lop;
            }
            ni2=0; step=reflect;
        }

        // expansion result: keep whichever of expanded / reflected point is better
        if (step==checkexpand) {
            if (of<rfp[nd]) copy(newp, sx[cp]);
            else copy(rfp, sx[cp]);
            step=reflect;
        }
        // contraction result: accept it, or shrink the whole simplex towards the lowest point
        if (step==checkcontract) {
            if (of<rfp[nd]) {	copy(newp, sx[cp]); step=reflect; } else {	cp=0; totalcon(); step=fill; oktocheckvar=true; return  nextpoint(); }
        }
        // classify the reflected point against the current simplex
        if (step==compare) {
            copy(newp, rfp);
            if (of<lop[nd]) {	expcon(2.0f); step=checkexpand; return newp; } //lowest
            if (of>=hip[nd]) {	copy(hip, rfp); expcon(-0.5f); step=checkcontract; oktocheckvar=true; return newp; } //highest
            if (sechigh()) {	expcon(0.5f); step=checkcontract; return newp; } //2nd highest
            else {	copy(newp, sx[cp]); step=reflect; } // in between
        }
        if (step==reflect) {
            if (oktocheckvar) {
                oktocheckvar=false;
                if (/*ni2>5 && */checkvar()) {	findhighlow(); cp=0; init(lop, tol2); step=checkopt; return nextpoint(); } else {	ni2++; step=reflect; }
            }
            findhighlow(); findcentre(true);
            expcon(1f); step=compare; return newp;
        }
        return null;
    }

    /** Copies all {@code np} slots (coordinates and objective value) of {@code a} into {@code b}. */
    void copy(float[] a, float[] b) {
        System.arraycopy(a, 0, b, 0, np);
    }

    /**
     * Reads the objective value for {@code newp} from the caller and stores it in
     * {@code of} and {@code newp[nd]}. NaN results and points outside
     * {@code [min, max]} are scored as {@code Float.MAX_VALUE}.
     */
    void getobjfunc() {
        of=oc.getobjfunc();
        if (Float.isNaN(of)) of=Float.MAX_VALUE;
        hitlimits=false;
        for (int d=0; d<nd; d++) {
            if (newp[d] > max[d] || newp[d] < min[d]) {
                of=Float.MAX_VALUE;
                report +="hit limits! ";
                // remembered so that the next report line, which rebuilds `report`, still carries the note
                hitlimits=true;
            }
        }
        newp[nd]=of;
    }

    /** Builds the progress line for the current pass, stores it in {@code report} and logs it. */
    void report() {
        logreport("");
    }

    /** Builds the progress line, appends the limit note and {@code suffix}, stores and logs it. */
    private void logreport(String suffix) {
        StringBuilder line=new StringBuilder();
        line.append("it=").append(ni).append("\t ").append(stepname[step]).append("\t ");
        // values are truncated to three decimals for display
        for (int d=0; d<nd+1; d++) line.append((int)(newp[d]*1000f)/1000f).append("\t ");
        if (hitlimits) line.append("hit limits! ");
        line.append(suffix);
        report=line.toString();
        log(report);
    }

    /**
     * Points {@code hip} / {@code lop} at the rows of {@code sx} with the highest /
     * lowest objective value and sets {@code cp} to the index of the highest row.
     * Ties keep the earlier row.
     */
    void findhighlow() {
        hip=sx[0]; lop=sx[0]; cp=0;
        for (int p=1; p<np; p++) {
            if (sx[p][nd]>hip[nd]) {	hip=sx[p]; cp=p; }
            if (sx[p][nd]<lop[nd]) {	lop=sx[p]; }
        }
    }

    /** @return true when no vertex other than {@code hip} has an objective value above {@code newp}'s */
    boolean sechigh() {
        for (int p=0; p<np; p++) {
            if (sx[p][nd]>newp[nd] && sx[p]!=hip) return false;
        }
        return true;
    }

    /**
     * Stores the mean position of the simplex vertices in {@code cen}.
     *
     * @param exchigh when true the highest vertex ({@code hip}) is left out of the mean
     */
    void findcentre(boolean exchigh) {
        for (int d=0; d<nd; d++) cen[d]=0;
        for (int p=0; p<np; p++) {
            if (exchigh && sx[p]==hip) continue;
            for (int d=0; d<nd; d++) cen[d]+=sx[p][d];
        }
        int count=(exchigh ? (np-1) : np);
        for (int d=0; d<nd; d++) cen[d]/=count;
    }

    /**
     * Sets {@code newp} to a fresh point on the line through {@code cen} and
     * {@code hip}: {@code (1+fac)*cen - fac*hip}. A new array is allocated because
     * {@code newp} may currently alias a row of {@code sx}.
     */
    void expcon(float fac) {
        newp=new float[np];
        for (int d=0; d<nd; d++) newp[d]=(1f+fac) *cen[d] - fac * hip[d];
    }

    /** Moves every vertex halfway towards the lowest vertex (objective values are left untouched). */
    void totalcon() {
        for (int p=0; p<np; p++) {
            for (int d=0; d<nd; d++)  sx[p][d]=(sx[p][d]+lop[d])/2;
        }
    }

    /**
     * Rebuilds the simplex around {@code guess}: vertex 0 is the guess itself and
     * vertex {@code d+1} is the guess offset by {@code perturb[d]*fac} in dimension {@code d}.
     *
     * @param guess base point; only the first {@code nd} entries are read
     * @param fac   scale applied to {@code perturb}
     */
    void init( float[] guess, float fac) {
        // guess may be a row of sx (init(lop, tol2)); snapshot it so that overwriting
        // that row does not leak its perturbation into the rows written after it
        float[] base=guess.clone();
        for (int p=0; p<np; p++) {
            for (int d=0; d<nd; d++) sx[p][d]= base[d]+ (p==d+1 ? perturb[d]*fac : 0 );
        }
    }

    /**
     * Computes {@code nsd}, the standard deviation of the objective values over the
     * simplex divided by their mean, and decides whether to start the optimality probe.
     *
     * @return true after more than 1000 passes, or when more than {@code np+5}
     *         passes have run and {@code |nsd| < tol}
     */
    boolean checkvar() {
        double sum=0, sum2=0;
        for (int p=0; p<np; p++) {
            double v=sx[p][nd]; // square in double: the float product loses the digits the variance depends on
            sum+=v; sum2+=v*v;
        }
        double mean= sum/np, var = (sum2/np) - (mean*mean);
        // rounding can push the variance of a flat simplex just below zero; clamp so sd is 0 rather than NaN
        double sd= Math.sqrt(Math.max(0.0, var));
        nsd=(float) (sd / mean);
        if (Float.isNaN(nsd) && showreport) log("optstatprob: "+mean+" "+ sum+" "+ sum2 + " "+var+" "+nsd);
        return ( ni>1000 || (ni>np+5 && ( nsd<tol && nsd > -tol)) );
    }

} //end class

//try {	System.in.read(); } catch (Exception e) {	}

