require "PMDArray"
require "dcl"    # dcl library interface

class SData

  # SData -- self-descriptive scalar data (rectangular)
  #
  # Couples one data array (@main) with one coordinate array per dimension
  # (@axes).  Every one of them must be a PMDArray.  @nd caches the rank
  # reported by the data array at construction time.

  # main :: the data, a PMDArray
  # axes :: one PMDArray per dimension of main, in dimension order
  #
  # Raises RuntimeError if main or any axis is not a PMDArray, or if the
  # number of axes differs from main.ndims.
  def initialize(main, *axes)
    unless main.is_a?(PMDArray)
      raise(RuntimeError, "main not a PMDArray")
    end
    @main = main
    @nd = @main.ndims
    if @nd != axes.length
      raise(RuntimeError, "number of the axes does not agree with the rank")
    end
    axes.each_with_index do |axis, i|
      unless axis.is_a?(PMDArray)
        raise(RuntimeError, "axes[#{i}] not a PMDArray")
      end
    end
    @axes = axes.dup
  end

  # Subset: applies idx to the data and idx[d] to the d-th axis, and returns
  # the result as a new object of the same class (self is left untouched).
  #
  # not completed -- so far, cannot deal with the rubber dimensions
  # (one index per dimension is expected; the cached rank is carried over
  # unchanged, use trim/trim! afterwards to drop length-1 dimensions).
  def [](*idx)
    sub_main = @main[*idx]
    sub_axes = (0...@nd).map { |d| @axes[d][idx[d]] }
    with_contents(sub_main, sub_axes)
  end


  ## redefined methods (those applied simply to @main) ##

  def shape; @main.shape; end
  def ndims; @nd; end
  def name; @main.name; end
  def name=(s); @main.name = s; end
  def units; @main.units; end
  def units=(s); @main.units = s; end

  ## redefined methods (others) ##

  # Arithmetic is delegated to the data array.  The right-hand side may be
  # another SData (its data array is used) or anything the data array itself
  # accepts.  The result keeps the axes of the receiver (in a fresh Array);
  # the axes of the right-hand side are not consulted.

  def +(a)
    with_contents(@main + unwrap(a), @axes.clone)
  end

  def -(a)
    with_contents(@main - unwrap(a), @axes.clone)
  end

  def *(a)
    with_contents(@main * unwrap(a), @axes.clone)
  end

  def /(a)
    with_contents(@main / unwrap(a), @axes.clone)
  end

  # Same operand handling as the other operators: an SData exponent is
  # replaced by its data array before being handed to @main.
  def **(a)
    with_contents(@main ** unwrap(a), @axes.clone)
  end

  ##

  # Copy of self whose data array is a clone of @main and whose axes list is
  # a new Array (the axis objects themselves are shared with the receiver).
  def dclone
    with_contents(@main.clone, @axes.clone)
  end

  # Drops every axis of length 1, asks the data array to trim itself, and
  # refreshes the cached rank.  Returns the new rank.
  def trim!
    @axes.delete_if { |axis| axis.length == 1 }
    @main.trim!
    @nd = @axes.length
  end

  # Non-destructive variant of trim!: returns a trimmed copy.
  def trim
    out = dclone
    out.trim!
    out
  end

  ## new methods ##

  # Accessors hand out copies (via dclone of the stored PMDArray), so callers
  # cannot modify the stored arrays through them.
  def main; @main.dclone; end
  def ax(dim); @axes[dim].dclone; end

  # Draws a contour plot of the data with DCL.
  # Raises RuntimeError unless the object is two-dimensional.
  def cont

    # Planned variant taking a subset (kept for reference):
    #def cont(*idx)
    #  if idx.length !=0 then
    #    # subset
    #    obj=self[*idx]
    #    obj.trim!
    #   else
    #    obj=self
    #  end

    if @nd != 2
      raise(RuntimeError, "Not a 2D array. Specify a 2D subset")
    end

    xax = @axes[0]
    yax = @axes[1]

    # graphics

    Dcl.gropn(1)
    Dcl.grfrm()
    Dcl.grswnd(xax.min, xax.max, yax.min, yax.max)
    Dcl.grsvpt(0.2, 0.8, 0.2, 0.8)
    Dcl.grstrn(1)
    Dcl.grstrf()
    Dcl.usdaxs()
    Dcl.uxmttl('T', @main.name, 0.0)
    Dcl.uxsttl('B', xax.name, 0.0)
    Dcl.uysttl('L', yax.name, 0.0)
    Dcl.uwsgxa(xax.to_a, shape[0])
    Dcl.uwsgya(yax.to_a, shape[1])
    Dcl.udcntr(@main.to_a, shape[0], shape[0], shape[1])
    Dcl.grcls()

  end

  # numerical derivative based on the differentiation btwn adjacent elements
  # with respect to the dim-th dimension.
  # NOTE: length of the dimension becomes one smaller
  #
  # The dim-th axis of the result is @axes[dim].zcen(0); the other axes are
  # shared with the receiver.
  def deriv(dim)
    new_main = @main.dif(dim) / @axes[dim].dif(0)
    new_axes = @axes.dup
    new_axes[dim] = @axes[dim].zcen(0)
    with_contents(new_main, new_axes)
  end

  # numerical derivative based on a centered differentiation
  # --- btwn the one before and after except at the ends ---
  # with respect to the dim-th dimension.
  # NOTE: length of the dimension is preserved.
  def cderiv(dim)
    new_main = @main.cdif(dim) / @axes[dim].cdif(0)
    with_contents(new_main, @axes.dup)
  end

  private

  # Shallow copy of self carrying the given data array and axes list.
  def with_contents(new_main, new_axes)
    out = dup
    out.instance_variable_set(:@main, new_main)
    out.instance_variable_set(:@axes, new_axes)
    out
  end

  # Right-hand operand for arithmetic: the data array of an SData,
  # anything else unchanged.
  def unwrap(a)
    a.is_a?(SData) ? a.main : a
  end

end

####### test for development ########

# Development smoke test, executed only when this file is run directly.
# Builds a (lon, lat, time) field, takes the first time slice, and plots it
# together with its adjacent-difference and centered derivatives along
# dimension 1.
if __FILE__ == $0
  nlon = 30
  nlat = 15
  ntime = 2

  # Dimension descriptors: one-entry hashes handed to PMDArray.new.
  # Written with "=>" because the old comma-pair Hash literal
  # ({'lon',nlon}) is a syntax error on Ruby 1.9 and later.
  londim = {'lon' => nlon}
  latdim = {'lat' => nlat}
  timedim = {'time' => ntime}

  # Math::PI is spelled out so the script does not depend on some required
  # library having done a top-level `include Math`.
  alon = NMDArray.indgen(nlon)/(nlon) * (2*Math::PI)
  alat = NMDArray.indgen(nlat)/(nlat-1.0) * (Math::PI) - Math::PI/2
  data = alon.rebin(nlon,nlat).sin  *  alat.rebin(nlat,nlon).transpose.cos

  main = PMDArray.new('T','K',londim,latdim,timedim)
  main.import1d(data.to_a.repeat(ntime))

  lon = PMDArray.new('lon','degrees_east',londim)
  lon.import1d(NArray.indgen(lon.length, 0.0, 360.0/lon.length))
  lat = PMDArray.new('lat','degrees_north',latdim)
  lat.import1d(NArray.indgen(lat.length, -90.0, 180.0/(lat.length-1)))
  time = PMDArray.new('time','days',timedim)
  time.import1d(NArray.indgen(time.length, 0.0, 1.0))

  temp = SData.new(main,lon,lat,time)
  #temp.main.prt

  # First time step; trim drops the length-1 dimension so cont accepts it.
  slice = temp[0..-1,0..-1,0]
  ss = slice.trim
  ss.cont
  #(ss**2).cont

  ss_xd = ss.deriv(1)
  ss_xd.cont
  p ss_xd.shape
  ss_xd.ax(0).prt
  ss_xd.ax(1).prt

  ss.cderiv(1).cont

end
