tssa_tao.py - pism - [fork] customized build of PISM, the parallel ice sheet model (tillflux branch)
 (HTM) git clone git://src.adamsgaard.dk/pism
 (DIR) Log
 (DIR) Files
 (DIR) Refs
 (DIR) LICENSE
       ---
       tssa_tao.py (10577B)
       ---
            1 # Copyright (C) 2012, 2014, 2015, 2016, 2018 David Maxwell and Constantine Khroulev
            2 #
            3 # This file is part of PISM.
            4 #
            5 # PISM is free software; you can redistribute it and/or modify it under the
            6 # terms of the GNU General Public License as published by the Free Software
            7 # Foundation; either version 3 of the License, or (at your option) any later
            8 # version.
            9 #
           10 # PISM is distributed in the hope that it will be useful, but WITHOUT ANY
           11 # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
           12 # FOR A PARTICULAR PURPOSE.  See the GNU General Public License for more
           13 # details.
           14 #
           15 # You should have received a copy of the GNU General Public License
           16 # along with PISM; if not, write to the Free Software
           17 # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
           18 
           19 """Inverse SSA solvers using the TAO library."""
           20 
           21 import PISM
           22 from PISM.util import Bunch
           23 from PISM.logging import logError
           24 from PISM.invert.ssa import InvSSASolver
           25 
           26 import sys
           27 import traceback
           28 
           29 
           30 class InvSSASolver_Tikhonov(InvSSASolver):
           31 
           32     """Inverse SSA solver based on Tikhonov iteration using TAO."""
           33 
           34     # Dictionary converting PISM algorithm names to the corresponding
           35     # TAO algorithms used to implement the Tikhonov minimization.
           36     tao_types = {}
           37 
           38     if (not PISM.imported_from_sphinx) and PISM.PETSc.Sys.getVersion() < (3, 5, 0):
           39         tao_types = {'tikhonov_lmvm': 'tao_lmvm',
           40                      'tikhonov_cg': 'tao_cg',
           41                      'tikhonov_lcl': 'tao_lcl',
           42                      'tikhonov_blmvm': 'tao_blmvm'}
           43     else:
           44         tao_types = {'tikhonov_lmvm': 'lmvm',
           45                      'tikhonov_cg': 'cg',
           46                      'tikhonov_lcl': 'lcl',
           47                      'tikhonov_blmvm': 'blmvm'}
           48 
           49 
           50     def __init__(self, ssarun, method):
           51         """
           52         :param ssarun: The :class:`PISM.invert.ssa.SSAForwardRun` defining the forward problem.
           53         :param method: String describing the actual algorithm to use. Must be a key in :attr:`tao_types`."""
           54 
           55         InvSSASolver.__init__(self, ssarun, method)
           56         self.listeners = []
           57         self.solver = None
           58         self.ip = None
           59         if self.tao_types.get(method) is None:
           60             raise ValueError("Unknown TAO Tikhonov inversion method: %s" % method)
           61 
           62     def addIterationListener(self, listener):
           63         """Add a listener to be called after each iteration.  See :ref:`Listeners`."""
           64         self.listeners.append(listener)
           65 
           66     def addDesignUpdateListener(self, listener):
           67         """Add a listener to be called after each time the design variable is changed."""
           68         self.listeners.append(listener)
           69 
           70     def solveForward(self, zeta, out=None):
           71         r"""Given a parameterized design variable value :math:`\zeta`, solve the SSA.
           72         See :cpp:class:`IP_TaucParam` for a discussion of parameterizations.
           73 
           74         :param zeta: :cpp:class:`IceModelVec` containing :math:`\zeta`.
           75         :param out: optional :cpp:class:`IceModelVec` for storage of the computation result.
           76         :returns: An :cpp:class:`IceModelVec` contianing the computation result.
           77         """
           78         ssa = self.ssarun.ssa
           79 
           80         reason = ssa.linearize_at(zeta)
           81         if reason.failed():
           82             raise PISM.AlgorithmFailureException(reason)
           83         if out is not None:
           84             out.copy_from(ssa.solution())
           85         else:
           86             out = ssa.solution()
           87         return out
           88 
           89     def solveInverse(self, zeta0, u_obs, zeta_inv):
           90         r"""Executes the inversion algorithm.
           91 
           92         :param zeta0: The best `a-priori` guess for the value of the parameterized design variable :math:`\zeta`.
           93         :param u_obs: :cpp:class:`IceModelVec2V` of observed surface velocities.
           94         :param zeta_inv: :cpp:class:`zeta_inv` starting value of :math:`\zeta` for minimization of the Tikhonov functional.
           95         :returns: A :cpp:class:`TerminationReason`.
           96         """
           97         eta = self.config.get_number("inverse.tikhonov.penalty_weight")
           98 
           99         design_var = self.ssarun.designVariable()
          100         if design_var == 'tauc':
          101             if self.method == 'tikhonov_lcl':
          102                 problemClass = PISM.IP_SSATaucTaoTikhonovProblemLCL
          103                 solverClass = PISM.IP_SSATaucTaoTikhonovProblemLCLSolver
          104                 listenerClass = TaucLCLIterationListenerAdaptor
          105             else:
          106                 problemClass = PISM.IP_SSATaucTaoTikhonovProblem
          107                 solverClass = PISM.IP_SSATaucTaoTikhonovSolver
          108                 listenerClass = TaucIterationListenerAdaptor
          109         elif design_var == 'hardav':
          110             if self.method == 'tikhonov_lcl':
          111                 problemClass = PISM.IP_SSAHardavTaoTikhonovProblemLCL
          112                 solverClass = PISM.IP_SSAHardavTaoTikhonovSolverLCL
          113                 listenerClass = HardavLCLIterationListenerAdaptor
          114             else:
          115                 problemClass = PISM.IP_SSAHardavTaoTikhonovProblem
          116                 solverClass = PISM.IP_SSAHardavTaoTikhonovSolver
          117                 listenerClass = HardavIterationListenerAdaptor
          118         else:
          119             raise RuntimeError("Unsupported design variable '%s' for InvSSASolver_Tikhonov. Expected 'tauc' or 'hardness'" % design_var)
          120 
          121         tao_type = self.tao_types[self.method]
          122         (stateFunctional, designFunctional) = PISM.invert.ssa.createTikhonovFunctionals(self.ssarun)
          123 
          124         self.ip = problemClass(self.ssarun.ssa, zeta0, u_obs, eta, stateFunctional, designFunctional)
          125         self.solver = solverClass(self.ssarun.grid.com, tao_type, self.ip)
          126 
          127         max_it = int(self.config.get_number("inverse.max_iterations"))
          128         self.solver.setMaximumIterations(max_it)
          129 
          130         pl = [listenerClass(self, l) for l in self.listeners]
          131 
          132         for l in pl:
          133             self.ip.addListener(l)
          134 
          135         self.ip.setInitialGuess(zeta_inv)
          136 
          137         vecs = self.ssarun.modeldata.vecs
          138         if vecs.has('zeta_fixed_mask'):
          139             self.ssarun.ssa.set_tauc_fixed_locations(vecs.zeta_fixed_mask)
          140 
          141         return self.solver.solve()
          142 
          143     def inverseSolution(self):
          144         """Returns a tuple ``(zeta,u)`` of :cpp:class:`IceModelVec`'s corresponding to the values
          145         of the design and state variables at the end of inversion."""
          146         zeta = self.ip.designSolution()
          147         u = self.ip.stateSolution()
          148         return (zeta, u)
          149 
          150 
          151 class TaucLCLIterationListenerAdaptor(PISM.IP_SSATaucTaoTikhonovProblemLCLListener):
          152 
          153     """Adaptor converting calls to a C++ :cpp:class:`IP_SSATaucTaoTikhonovProblemListener`
          154     on to a standard python-based listener.  Used internally by
          155     :class:`InvSSATaucSolver_Tikhonov`.  I.e. don't make one of these for yourself."""
          156 
          157     def __init__(self, owner, listener):
          158         """:param owner: The :class:`InvSSATaucSolver_Tikhonov` that constructed us
          159            :param listener: The python-based listener.
          160          """
          161         PISM.IP_SSATaucTaoTikhonovProblemLCLListener.__init__(self)
          162         self.owner = owner
          163         self.listener = listener
          164 
          165     def iteration(self, problem, eta, it, objVal, penaltyVal, d, diff_d, grad_d, u, diff_u, grad_u, constraints):
          166         """Called during IP_SSATaucTaoTikhonovProblemLCL iterations.  Gathers together the long list of arguments
          167         into a dictionary and passes it along in standard form to the python listener."""
          168 
          169         data = Bunch(tikhonov_penalty=eta, JDesign=objVal, JState=penaltyVal,
          170                      zeta=d, zeta_step=diff_d, grad_JDesign=grad_d,
          171                      u=u, residual=diff_u, grad_JState=grad_u,
          172                      constraints=constraints)
          173         try:
          174             self.listener(self.owner, it, data)
          175         except Exception:
          176             logError("\nERROR: Exception occured during an inverse solver listener callback:\n\n")
          177             traceback.print_exc(file=sys.stdout)
          178             raise
          179 
          180 
          181 class TaucIterationListenerAdaptor(PISM.IP_SSATaucTaoTikhonovProblemListener):
          182 
          183     """Adaptor converting calls to a C++ :cpp:class:`IP_SSATaucTaoTikhonovProblemListener`
          184     on to a standard python-based listener.  Used internally by
          185     :class:`InvSSATaucSolver_Tikhonov`.  I.e. don't make one of these for yourself."""
          186 
          187     def __init__(self, owner, listener):
          188         """:param owner: The :class:`InvSSATaucSolver_Tikhonov` that constructed us
          189            :param listener: The python-based listener.
          190          """
          191         PISM.IP_SSATaucTaoTikhonovProblemListener.__init__(self)
          192         self.owner = owner
          193         self.listener = listener
          194 
          195     def iteration(self, problem, eta, it, objVal, penaltyVal, d, diff_d, grad_d, u, diff_u, grad_u, grad):
          196         """Called during IP_SSATaucTaoTikhonovProblem iterations.  Gathers together the long list of arguments
          197         into a dictionary and passes it along in a standard form to the python listener."""
          198         data = Bunch(tikhonov_penalty=eta, JDesign=objVal, JState=penaltyVal,
          199                      zeta=d, zeta_step=diff_d, grad_JDesign=grad_d,
          200                      u=u, residual=diff_u, grad_JState=grad_u, grad_JTikhonov=grad)
          201         try:
          202             self.listener(self.owner, it, data)
          203         except Exception:
          204             logError("\nERROR: Exception occured during an inverse solver listener callback:\n\n")
          205             traceback.print_exc(file=sys.stdout)
          206             raise
          207 
          208 
          209 class HardavIterationListenerAdaptor(PISM.IP_SSAHardavTaoTikhonovProblemListener):
          210 
          211     """Adaptor converting calls to a C++ :cpp:class:`IP_SSATaucTaoTikhonovProblemListener`
          212     on to a standard python-based listener.  Used internally by
          213     :class:`InvSSATaucSolver_Tikhonov`.  I.e. don't make one of these for yourself."""
          214 
          215     def __init__(self, owner, listener):
          216         """:param owner: The :class:`InvSSATaucSolver_Tikhonov` that constructed us
          217            :param listener: The python-based listener.
          218          """
          219         PISM.IP_SSAHardavTaoTikhonovProblemListener.__init__(self)
          220         self.owner = owner
          221         self.listener = listener
          222 
          223     def iteration(self, problem, eta, it, objVal, penaltyVal, d, diff_d, grad_d, u, diff_u, grad_u, grad):
          224         """Called during IP_SSATaucTaoTikhonovProblem iterations.  Gathers together the long list of arguments
          225         into a dictionary and passes it along in a standard form to the python listener."""
          226         data = Bunch(tikhonov_penalty=eta, JDesign=objVal, JState=penaltyVal,
          227                      zeta=d, zeta_step=diff_d, grad_JDesign=grad_d,
          228                      u=u, residual=diff_u, grad_JState=grad_u, grad_JTikhonov=grad)
          229         try:
          230             self.listener(self.owner, it, data)
          231         except Exception:
          232             logError("\nERROR: Exception occured during an inverse solver listener callback:\n\n")
          233             traceback.print_exc(file=sys.stdout)
          234             raise