tlinesearchHZ.py - pism - [fork] customized build of PISM, the parallel ice sheet model (tillflux branch)
 (HTM) git clone git://src.adamsgaard.dk/pism
 (DIR) Log
 (DIR) Files
 (DIR) Refs
 (DIR) LICENSE
       ---
       tlinesearchHZ.py (11504B)
       ---
            1 ############################################################################
            2 #
            3 #  This file is a part of siple.
            4 #
            5 #  Copyright 2010, 2014 David Maxwell
            6 #
            7 #  siple is free software: you can redistribute it and/or modify
            8 #  it under the terms of the GNU General Public License as published by
            9 #  the Free Software Foundation, either version 2 of the License, or
           10 #  (at your option) any later version.
           11 # 
           12 ############################################################################
           13 
           14 # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
           15 # %%
           16 # %%  linesearchHZ
           17 # %%
           18 # %%  Finds an approximate minimizer of F(t) for t in [0,infinity) satisfying Wolfe conditions.
           19 # %%
           20 # %%  Algorithm from: Hager, W. and Zhang, H. CG DESCENT, a Conjugate Gradient Method with Guaranteed Descent
           21 # %%                  Algorithm 851. ACM Transactions on Mathematical Software. 2006;32(1):113-137.
           22 # %%
           23 # %%  Usage:  status = linesearchHZ(F0, F0p, F, t, params)
           24 # %%
           25 # %%  In: 
           26 # %%  F0      - F(0)
           27 # %%  F0p     - F'(0)
           28 # %%  F       - Function to minimize.  See requirements below.
           29 # %%  t       - Initial guess for location of minimizer.
           30 # %%  params  - Optional struct of control parameters.  See below.
           31 # %%  
           32 # %%  Out:
           33 # %%  status  - Structure with the following fields:
           34 # %%  code    - nonnegative integer with 0 indicating no error
           35 # %%  msg     - if code>0, a descriptive error message about why the algorithm failed
           36 # %%  val       - if code==0, structure containing information about the minimizer
           37 # %%  val.t     - location of minimizer
           38 # %%  val.F     - value of F(val.t)
           39 # %%  val.Fp    - value of F'(val.t)
           40 # %%  val.data  - additional data at minimizer.  See below.
           41 # %%
           42 # %%  The function F must have the following signature: function [f, fp, fdata] = F(t).
           43 # %%  The f and fp are the values of the function and derivative at t.  For some functions,
           44 # %%  there is expensive data that are computed along the way that might be needed by the end user.
           45 # %%  The data field allows this excess data to be saved and returned.  In the end it will show up
           46 # %%  in status.val.data.
           47 # %%
           48 # %%  The various parameters that control the algorithm are described in the above reference and
           49 # %%  have the same name.  The principal ones you might want to change:
           50 # %%
           51 # %%  delta: Controls sufficient decrease Wolfe condition:  delta*F'(0) >= (F(t)-F(0))/t
           52 # %%  sigma: Controls sufficient shallowness Wolfe condition:  F'(t) >= sigma*F'(0)
           53 # %%  rho:   Expansion factor for initial bracket search (bracket expands by multiples of rho)
           54 # %%  nsecant: Maximum number of outermost loops
           55 # %%  nshrink: Maximum number of interval shrinks or expansions in a single loop
           56 # %%  verbose: Print lots of messages to track the algorithm
           57 # %%  debug:   Do extra computations to verify the state is consistent as the algorithm progresses.
           58 # %%
           59 # %%  
           60 # 
           61 
           62 from siple.params import Bunch, Parameters
           63 from siple.reporting import msg, pause
           64 import numpy
           65 
           66 class LinesearchHZ:
           67   
           68   @staticmethod
           69   def defaultParameters():
           70     return Parameters('linesearchHZ', delta=.1, sigma=.9, epsilon=0, theta=.5, gamma=.66,  rho=5,
           71                               nsecant=50, nshrink=50, verbose=False, debug=True);
           72 
           73   def __init__(self,params=None):
           74     self.params = self.defaultParameters()
           75     if not (params is None): self.params.update(params)
           76 
           77   def error(self):
           78     return self.code > 0
           79     
           80   def ezsearch(self,F,t0=None):
           81     self.F = F
           82     z = self.eval(0)
           83     if t0 is None:
           84       t0 = 1./(1.-z.F0p);
           85     return self.search(F,z.F,z.Fp,t0)
           86     
           87   def search(self,F,F0,F0p,t0):
           88     self.code = -1
           89     self.errMsg = 'no error'
           90     self.F = F
           91 
           92     params = self.params
           93 
           94     z = Bunch(F=F0,Fp=F0p,t=0,data=None)
           95     assert F0p <= 0
           96 
           97     # % Set up constants for checking Wolfe conditions.
           98     self.wolfe_lo = params.sigma*z.Fp;
           99     self.wolfe_hi = params.delta*z.Fp;
          100     self.awolfe_hi = (2*params.delta-1)*z.Fp;
          101     self.fpert = z.F + params.epsilon;
          102     self.f0 = z.F;
          103 
          104     if params.verbose: msg('starting at z=%g (%g,%g)', z.t, z.F, z.Fp)
          105 
          106     while True:
          107       c = self.eval(t0)
          108       if not numpy.isnan(c.F):
          109         break
          110       msg('Hit a NaN in initial evaluation at t=%g',t0)        
          111       t0 *= 0.5
          112     
          113     if params.verbose: msg('initial guess c=%g (%g,%g)', c.t, c.F, c.Fp)
          114 
          115     if self.wolfe(c):
          116       if params.verbose: msg('done at init')
          117       self.setDone(c)
          118       return
          119 
          120     (aj,bj) = self.bracket(z,c)
          121     if params.verbose: msg('initial bracket %g %g',aj.t,bj.t)
          122    
          123     if self.code >= 0:
          124       self.doneMsg('initial bracket')
          125       return
          126 
          127     if params.debug: self.verifyBracket(aj,bj)
          128 
          129     count = 0;
          130 
          131     while True:
          132       count += 1;
          133 
          134       if count> params.nsecant:
          135         self.setError('too many bisections in main loop')
          136         return
          137 
          138       (a,b) = self.secantsq(aj,bj);
          139       if params.verbose: msg('secantsq a %g b %g', a.t, b.t)
          140       if params.verbose: self.printBracket(a,b)
          141       if self.code >= 0:
          142         self.doneMsg('secant');
          143         return
          144       
          145       if  (b.t-a.t) > params.gamma*(bj.t-aj.t):
          146         (a,b) = self.update(a, b, (a.t+b.t)/2);      
          147         if params.verbose: msg('update to a %g b %g', aj.t, bj.t)
          148         if params.verbose: self.printBracket(a,b)
          149         if self.code >= 0:
          150           self.doneMsg('bisect');
          151           return
          152       aj = a
          153       bj = b
          154 
          155   def printBracket(self,a,b):
          156     msg('a %g b %g f(a) %g fp(a) %g f(b) %g fp(b) %g fpert %g', a.t, b.t, a.F, a.Fp, b.F, b.Fp, self.fpert)
          157 
          158   def doneMsg(self,where):
          159     if self.code > 0:
          160       msg('done at %s with error status: %s', where, self.errMsg);
          161     else:
          162       if self.params.verbose: msg('done at %s with val=%g (%g, %g)', where, self.value.t, self.value.F, self.value.Fp);
          163 
          164   def verifyBracket(self,a,b):
          165     good = (a.Fp<=0) and (b.Fp >=0) and (a.F<= self.fpert);
          166     if not good:
          167       msg('bracket inconsistent: a %g b %g f(a) %g fp(a) %g f(b) %g fp(b) %g fpert %g', a.t, b.t, a.F, a.Fp, b.F, b.Fp, self.fpert)
          168       pause()
          169     if (a.t>=b.t):
          170       msg('bracket not a bracket (a>=b): a %g b %g f(a) %g fp(a) %g f(b) %g fp(b) %g fpert %g', a.t, b.t, a.F, a.Fp, b.F, b.Fp, self.fpert)
          171 
          172   def setDone(self, c):
          173     self.code = 0
          174     self.value = c
          175     
          176   def setError(self,msg):
          177     self.code = 1
          178     self.errMsg = msg
          179 
          180   def update(self, a, b, ct):
          181     abar = a
          182     bbar = b
          183     
          184     params = self.params
          185     
          186     if params.verbose: msg('update %g %g %g', a.t, b.t, ct);
          187     if  (ct<=a.t) or (ct>=b.t):
          188       if params.verbose: msg('midpoint out of interval')
          189       return (abar,bbar)
          190 
          191     c = self.eval(ct)
          192 
          193     if self.wolfe(c):
          194       self.setDone(c)
          195       return (abar,bbar)
          196 
          197     if c.Fp >= 0:
          198       if params.verbose: msg('midpoint with non-negative slope. Becomes b.')
          199       abar = a;
          200       bbar = c;
          201       if params.debug: self.verifyBracket(abar,bbar)
          202       return (abar,bbar)
          203 
          204     if c.F <= self.fpert:
          205       if params.verbose: msg('midpoint with negative slope, small value. Becomes a.')
          206       abar = c;
          207       bbar = b;
          208       if params.debug: self.verifyBracket(abar,bbar)
          209       return (abar,bbar)
          210 
          211     if params.verbose: msg('midpoint with negative slope, large value. Shrinking to left.')
          212     (abar,bbar) = self.ushrink(a, c);
          213     if params.debug: self.verifyBracket(abar,bbar)
          214     
          215     return (abar,bbar)
          216 
          217   def ushrink(self,a,b):
          218     abar = a;
          219     bbar = b;
          220 
          221     count = 0;
          222     while True:
          223       count += 1;
          224       
          225       if self.params.verbose:
          226         msg('in ushrink')
          227         self.printBracket(abar,bbar)
          228       if count > self.params.nshrink:
          229         self.setError('too many contractions in ushrink')
          230         return (abar,bbar)
          231     
          232       d=self.eval((1-self.params.theta)*abar.t+self.params.theta*bbar.t);
          233       if self.wolfe(d):
          234         self.setDone(d)
          235         return (abar,bbar)
          236 
          237       if d.Fp>=0:
          238         bbar = d;
          239         return (abar,bbar)
          240     
          241       if d.F <= self.fpert:
          242         abar=d;
          243       else:
          244         bbar=d;
          245 
          246   def plotInterval(self,a,b,N=20):
          247     from matplotlib import pyplot as pp
          248     import numpy as np
          249     T=np.linspace(a.t,b.t,N)
          250     FT=[]
          251     FpT=[]
          252     for t in T:
          253       c=self.eval(t)
          254       FT.append(c.F)
          255       FpT.append(c.Fp)
          256     pp.subplot(1,2,1)
          257     pp.plot(T,np.array(FT))
          258     pp.subplot(1,2,2)
          259     pp.plot(T,np.array(FpT))
          260     pp.draw()
          261 
          262   def secant(self,a,b):
          263     # % What if a'=b'?  We'll generate a +/-Inf, which will subsequently test as being out
          264     # % of any interval when 'update' is subsequently called.  So this seems safe.
          265     
          266     if self.params.verbose: msg('secant: a %g fp(a) %4.8g b %g fp(b) %4.8g',a.t,a.Fp, b.t, b.Fp)
          267     if (a.t==b.t):
          268       msg('a=b, inconcievable!')
          269     if -a.Fp <= b.Fp:
          270       return a.t-(a.t-b.t)*(a.Fp/(a.Fp-b.Fp));
          271     else:
          272       return b.t-(a.t-b.t)*((b.Fp)/(a.Fp-b.Fp));
          273 
          274   def secantsq(self,a,b):
          275     ct = self.secant(a,b)
          276     if self.params.verbose: msg('first secant to %g', ct)
          277     (A,B) = self.update(a,b,ct)
          278     if self.code >= 0:
          279       return (A,B)
          280 
          281     if B.t == ct:
          282       ct2 = self.secant(b,B);
          283       if self.params.verbose: msg('second secant on left half A %g B %g with c=%g',A.t, B.t, ct2)
          284       (abar,bbar) = self.update(A,B,ct2)
          285     elif A.t == ct:
          286       ct2 = self.secant(a,A);
          287       if self.params.verbose: msg('second secant on right half A %g B %g with c=%g',A.t, B.t, ct2)
          288       (abar,bbar) = self.update(A,B,ct2)
          289     else:
          290       if self.params.verbose: msg('first secant gave a shrink in update. Keeping A %g B %g',A.t, B.t)
          291       abar = A; bbar = B
          292     
          293     return (abar,bbar)
          294 
          295 
          296   def bracket(self, z, c):
          297     a = z
          298     b = c
          299     
          300     count = 0
          301     while True:
          302       if count > self.params.nshrink:
          303         self.setError('Too many expansions in bracket')
          304         return (a,b)
          305       count += 1
          306 
          307       if b.Fp >= 0:
          308         if self.params.verbose: msg('initial bracket ends with expansion: b has positive slope')
          309         return (a,b)
          310     
          311       if b.F > self.fpert:
          312         if self.params.verbose: msg('initial bracket contraction')
          313         return self.ushrink(a,b);
          314 
          315       if self.params.verbose: msg('initial bracket expanding')
          316       a = b;
          317       rho = self.params.rho
          318       while True:
          319         if count > self.params.nshrink:
          320           self.setError('Unable to find a valid input')
          321           return (a,b)
          322         c = self.eval(rho*b.t)
          323         if not numpy.isnan(c.F):
          324           b = c
          325           break
          326         msg('Hit a NaN at t=%g',rho*b.t)
          327         rho*=0.5
          328         count += 1
          329 
          330       if self.wolfe(b):
          331         #msg('decrease %g slope %g f0 %g fpert %g', b.F-params.f0, b.t*params.wolfe_hi, params.f0, params.fpert)
          332         self.setDone(b);
          333         return(a,b)
          334 
          335   def wolfe(self,c):
          336     if self.params.verbose: msg('checking wolfe of c=%g (%g,%g)',c.t,c.F,c.Fp)
          337 
          338     if c.Fp >= self.wolfe_lo:
          339       if (c.F-self.f0) <= c.t*self.wolfe_hi:
          340         return True
          341 
          342       if self.params.verbose: msg('failed sufficient decrease')
          343 
          344       # % if ((c.F <= params.fpert) && (c.Fp <= params.awolfe_hi))
          345       # %   msg('met awolfe')
          346       # %   met = true;
          347       # %   return;
          348       # % end
          349       # if params.verbose: msg('failed awolfe sufficient decrease')
          350     else:
          351       if self.params.verbose: msg('failed slope flatness')
          352 
          353     return False
          354 
          355   def eval(self,t):
          356     c = Bunch(F=0,Fp=0,data=None,t=t)
          357     (c.F,c.Fp,c.data) = self.F(t)
          358     c.F = float(c.F)
          359     c.Fp = float(c.Fp)
          360     return c
          361 
          362 if __name__ == '__main__':
          363 
          364   lsParams = Parameters('tmp', verbose=True, debug=True)
          365   ls = LinesearchHZ(params=lsParams)
          366   F = lambda t: (-t*(1-t), -1+2*t,None)
          367   ls.ezsearch(F,5)
          368   if ls.error():
          369     print(ls.errMsg)
          370   else:
          371     v = ls.value
          372     print('minimum of %g at t=%g' % (v.F,v.t))