require "numru/gphys"
require "numru/dclext"  # for SHTLIB and its extensions (sht_get_n, sht_get_m)

module NumRu
  module GAnalysis

    # Spherical harmonics library for equally spaced lon-lat grids using
    # SHTLIB in DCL.
    #
    # REQUIREMENT
    # * This module can handle grid data that satisfy the following:
    #   * dimension 0 (first dim) is longitude, dimension 1 is latitude.
    #   * must be equally spaced and cover the globe (equal spacing is
    #     not checked), such as lon = 0, 2,.., 358 [,360], or
    #     -180, -170,..,170 [,180], etc. (in increasing order; cyclic
    #     extension is applied internally)
    #     and lat = -90,-88,..,90 (or 90,88,..,-90; pole to pole).
    #   * after the cyclic extension, the numbers of longitude and
    #     latitude grid points must both be odd (2*im+1 and 2*jm+1).
    # * Data missing is not allowed.
    #
    module SphHarmonicIsoG
      # Cache of the SHTLIB work array and the (mm, jm, im) it was built for.
      @@work = @@im = @@jm = @@mm = nil

      module_function

      # Check the validity of the lon-lat grid and (re)initialize the
      # SHTLIB work array if needed.
      #
      # This method raises an exception if the data is not acceptable.
      #
      # ARGUMENTS
      # * gphys [GPhys] : data to check its grid
      # * mm [Integer] : truncation wavenumber
      #
      # RETURN VALUE
      # * gphys [GPhys] : same as input, except that the 1st dim is
      #   cyclically extended if needed.
      # * np2sp [true or false] : true if data is N.pole to S.pole,
      #   false if data is S.pole to N.pole
      #
      # SIDE EFFECTS
      # * sets @@work, @@mm, @@im and @@jm; the work array (DCL.shtint)
      #   is rebuilt only when mm or the grid size differs from the cache.
      #
      # RAISES
      # * ArgumentError for rank < 2, decreasing longitude, a latitude
      #   axis that is not pole to pole, an even number of grid points,
      #   or an mm too large for the grid.
      def check_and_init(gphys, mm)

        raise(ArgumentError, "Invalid rank (#{gphys.rank})") if gphys.rank < 2

        lon = gphys.coord(0).val
        if lon[0] > lon[-1]
          raise ArgumentError, "Longitude must be in the increasing order."
        end
        gphys = gphys.cyclic_ext(0)

        lat = gphys.coord(1).convert_units("degrees").val
        eps = 0.1   # tolerance [degrees] used to detect the poles
        if (lat[0]-90.0).abs < eps && (lat[-1]+90.0).abs < eps
          # N.pole to S.pole
          np2sp = true
        elsif (lat[0]+90.0).abs > eps || (lat[-1]-90.0).abs > eps
          raise ArgumentError, "Not pole to pole: #{lat[0]..lat[-1]}."
        else
          # S.pole to N.pole
          np2sp = false
        end

        nx, ny, = gphys.shape
        # im and jm are derived as (n-1)/2, i.e. the grid is assumed to
        # have 2*im+1 and 2*jm+1 points. With an even count the integer
        # division drops a point, and the grids exchanged with SHTLIB
        # would not match the data shape, so reject it explicitly.
        if nx.even? || ny.even?
          raise(ArgumentError, "The numbers of grid points must be odd " +
                "(after cyclic extension): nlon=#{nx}, nlat=#{ny}")
        end
        im = (nx-1)/2
        jm = (ny-1)/2

        if (mm+1)/2 > jm || mm+1 > im
          raise(ArgumentError,"mm=#{mm} is too big for im=#{im} & jm=#{jm}")
        end
        if @@work.nil? || @@mm != mm || @@jm != jm || @@im != im
          @@work = DCL.shtint(mm,jm,im)
          @@mm = mm
          @@jm = jm
          @@im = im
        end
        [gphys, np2sp]
      end

      # Spherical harmonics filter
      #
      # ARGUMENTS
      # * gp [GPhys] : grid data to filter (lon,lat,...: rank >= 2)
      # * mm [Integer] : truncation wavenumber
      # * deriv (optional) [nil(default) or Symbol]
      #   (let \lambda be longitude and \phi be latitude)
      #   if :xgrad (or :xdiv), applies del / cos\phi del\lambda,
      #   if :ygrad, applies del / del\phi,
      #   if :ydiv, applies del cos\phi / cos\phi del\phi.
      # * lap (optional) [nil(default) or Integer] If 1, the laplacian
      #   is taken; if -1, the inverse laplacian is taken -- note
      #   that you can also explicitly take the laplacian by
      #   using the factor_n_m block (see below).
      #
      # * factor_n_m (optional block)
      #   spectral filter in the form of {|n,m|  } that
      #   returns the factor if n and m are integers.
      #   Example {|n,m| -n*(n+1)} to take laplacian.
      #
      # RETURN VALUE
      # * a GPhys on the same grid (and with the same first-dim length)
      #   as the input gp.
      def filt( gp, mm, deriv=nil, lap=nil, &factor_n_m )
        iswg2s = isws2g = 0
        case deriv
        when :xgrad, :xdiv
          isws2g = -1
        #when :xderiv
        ##(horinout, developers' memo: just multiply with cos\phi afterward)
        #  iswg2s = -1   # don't know if isws2g is better
        #  raise("Under development: need cos_phi factor")
        when :ygrad
          isws2g = 1
        when :ydiv
          iswg2s = 1
        when nil
          # no derivative
        else
          raise ArgumentError,"Unsupported operation #{deriv}"
        end
        nlon = gp.shape[0]
        gp, np2sp = check_and_init(gp, mm)   # sets @@work, @@mm, @@im & @@jm
        na = gp.val
        na = na.to_na if na.respond_to?(:to_na)   # for NArrayMiss
        shape = na.shape
        naf = NArray.float( *shape )  # output variable (filtered NArray)
        f = factor_n_m ? __factor_n_m( &factor_n_m ) : nil
        # Meridional derivatives of y-reversed (N->S) data change sign.
        flip_sign = np2sp && (isws2g == 1 || iswg2s == 1)
        loop_for_3rd_or_greater_dim(shape){|sel|
          _w, s = DCL.shtg2s(@@mm,iswg2s,na[*sel],@@work)
          s = DCL.shtlap(@@mm,lap,s) if lap
          s = s * f if f
          s = -s if flip_sign
          _w, g = DCL.shts2g(@@mm,@@jm,@@im,isws2g,s,@@work)
          naf[*sel] = g
        }
        vaf = VArray.new(naf, gp.data, gp.name)
        gp = GPhys.new( gp.grid, vaf)
        if gp.shape[0] != nlon
          # cyclically extended --> trim it to the original shape
          gp = gp[0...nlon,false]
        end
        gp
      end

      # Zonal gradient on the unit sphere: del / cos\phi del\lambda
      # (same as filt(gp, mm, :xgrad, &factor_n_m)).
      # Also available as xdiv.
      def xgrad( gp, mm, &factor_n_m )
        filt( gp, mm, :xgrad, &factor_n_m )
      end
      alias :xdiv :xgrad
      module_function :xdiv

      # Meridional gradient on the unit sphere: del / del\phi
      # (same as filt(gp, mm, :ygrad, &factor_n_m)).
      def ygrad( gp, mm, &factor_n_m )
        filt( gp, mm, :ygrad, &factor_n_m )
      end

      # Meridional divergence on the unit sphere:
      # del cos\phi / cos\phi del\phi
      # (same as filt(gp, mm, :ydiv, &factor_n_m)).
      def ydiv( gp, mm, &factor_n_m )
        filt( gp, mm, :ydiv, &factor_n_m )
      end

      # Horizontal Laplacian on the sphere (or its powers)
      #
      # * gp [GPhys] : grid data (lon,lat,...: rank >= 2)
      # * mm [Integer] : truncation wavenumber
      # * radius (optional; default=1.0) [Numeric(if non-dim) or UNumeric]
      #   radius of the sphere
      # * order (optional; default=1) [Integer] power of the Laplacian:
      #   1 is the Laplacian, -1 the inverse Laplacian (n=0 component set
      #   to zero for negative orders), 0 only the truncation.
      #   The result is scaled by radius**(-2*order).
      def lapla_h( gp, mm, radius=1.0, order=1 )
        if order == 0
          gp = filt( gp, mm )   # identity apart from the truncation
        elsif order==-1 || order==1
          gp = filt( gp, mm, nil, order )
        elsif order > 0
          gp = filt( gp, mm ){|n,m| (-n*(n+1)).to_f**order }
        else
          # negative --> avoid zero division
          gp = filt( gp, mm ){|n,m| n==0 ? 0.0 : (-n*(n+1)).to_f**order }
        end
        # each application of the Laplacian carries a factor radius**(-2)
        gp *= radius**(-2*order) if radius != 1.0 && order != 0
        gp
      end

      # Horizontal divergence on the sphere
      #
      # ARGUMENTS
      # * u,v [GPhys] : the x and y components to take divergence
      # * mm [Integer] : truncation wavenumber
      # * radius (optional; default=1.0) [Numeric(if non-dim) or UNumeric]
      #   radius of the sphere
      # * factor_n_m (optional block) : spectral filter, as in filt
      #
      # RETURN VALUE
      # * GPhys named "div_h"
      def div_h(u,v, mm, radius=1.0, &factor_n_m )
        gp = xdiv(u, mm, &factor_n_m) + ydiv(v, mm, &factor_n_m)
        gp *= radius**(-1) if radius != 1.0
        gp.long_name = "div_h(#{u.name},#{v.name})"
        gp.name = "div_h"
        gp
      end

      # Horizontal rotation on the sphere
      #
      # * u,v [GPhys] : the x and y components to take rotation
      # * mm [Integer] : truncation wavenumber
      # * radius (optional; default=1.0) [Numeric(if non-dim) or UNumeric]
      #   radius of the sphere
      # * factor_n_m (optional block) : spectral filter, as in filt
      #
      # RETURN VALUE
      # * GPhys named "rot_h"
      def rot_h(u,v, mm, radius=1.0, &factor_n_m )
        gp = xdiv(v, mm, &factor_n_m) - ydiv(u, mm, &factor_n_m)
        gp *= radius**(-1) if radius != 1.0
        gp.long_name = "rot_h(#{u.name},#{v.name})"
        gp.name = "rot_h"
        gp
      end

      # Evaluate the block factor_n_m for every spectral component of the
      # current truncation (@@mm), in the order given by
      # DCLExt.sht_get_n / sht_get_m. Returns an NArray.float.
      def __factor_n_m( &factor_n_m )
        ms = DCLExt.sht_get_m(@@mm)
        ns = DCLExt.sht_get_n(@@mm)
        f = NArray.float(ms.length)
        ms.length.times{|i| f[i] = factor_n_m.call(ns[i],ms[i]) }
        f
      end

      # Yield a selector [true,true,i2,i3,...] for every combination of
      # indices along dimensions 2 and beyond of `shape` (dimension 2
      # varies fastest). For a rank-2 shape, yields [true,true] once.
      def loop_for_3rd_or_greater_dim(shape,&block)
        raise(ArgumentError, "block not given") unless block
        sh3 = shape[2..-1]
        rank3 = sh3.length
        # csh3[d]: stride of dimension d+2 in the flattened index
        csh3 = [1]
        (1...rank3).each{|d| csh3[d] = sh3[d-1]*csh3[d-1]}
        len = sh3.inject(1){|prod, n| prod*n}
        len.times do |i|
          sel = [true,true]
          rank3.times{|d| sel.push( (i/csh3[d]) % sh3[d] ) }
          block.call(sel)
        end
      end

    end

  end
end

#######################################################
# test / demo part
#######################################################
if __FILE__ == $0
  require "numru/ggraph"
  include NumRu
  include GAnalysis
  include NMath

  # Test field: read "T", then overwrite its last index along the last
  # dimension with an analytic pattern so the results can be eyeballed.
  # (Other patterns tried: sin(lon_rad)*cos(lat_rad),
  #  (lon_rad*0+1)*cos(lat_rad) [+ 1].)
  temp = GPhys::IO.open("../../../testdata/T.jan.nc","T").copy
  deg2rad = PI/180.0
  lon_rad = temp.coord(0).val.newdim(1) * deg2rad
  lat_rad = temp.coord(1).val.newdim(0) * deg2rad
  temp[false,-1] = sin(lon_rad)*cos(lat_rad)**2

  trunc = 17
  coef = 1.0 / ( trunc*trunc*(trunc+1)*(trunc+1) )

  filtered = SphHarmonicIsoG.filt(temp,trunc)
  p filtered

  # smooth damping of high total wavenumbers, clipped at zero
  damped = SphHarmonicIsoG.filt(temp,trunc){|n,m|
    [0.0, 1.0 - coef*n*n*(n+1)*(n+1)].max
  }

  x_grad = SphHarmonicIsoG.xgrad(temp,trunc)
  y_grad = SphHarmonicIsoG.ygrad(temp,trunc)
  y_div = SphHarmonicIsoG.ydiv(temp,trunc)

  laplacian = SphHarmonicIsoG.lapla_h(temp,trunc)

  # --- graphics setup ---
  DCL.swpset('iwidth',960)
  DCL.swpset('iheight',720)
  DCL.sgscmn(10)
  DCL.gropn(1)
  DCL.sldiv("t",2,3)
  DCL.sgpset("lfull",true)
  DCL.sgpset('isub', 96)      # control character of subscription: '_' --> '`'
  DCL.glpset('lmiss',true)

  last = -1   # last index along the last dimension (the analytic field)

  GGraph.set_fig "itr"=>10,"viewport"=>[0.15,0.85,0.07,0.4]
  GGraph.set_tone "color_bar"=>true
  GGraph.tone temp
  GGraph.tone filtered, true, "keep"=>true
  GGraph.tone damped, true, "keep"=>true

  # zonal gradient
  GGraph.tone x_grad[false,1], true
  GGraph.tone filtered[false,last], true
  GGraph.tone x_grad[false,last], true

  # meridional gradient / divergence
  GGraph.tone y_grad[false,0], true
  GGraph.tone filtered[false,last], true
  GGraph.tone y_grad[false,last], true, "min"=>-1,"max"=>1
  GGraph.tone y_div[false,last], true, "min"=>-1.5,"max"=>1.5
  GGraph.tone y_grad[false,last]*1.5-y_div[false,last], true, "log"=>true

  GGraph.tone laplacian[false,last]

  # meridional cross sections
  GGraph.set_fig "itr"=>1
  GGraph.line filtered.cut(true,40,false), true
  GGraph.line x_grad.cut(true,40,false), true
  DCL.grcls
end
