Package cc.mallet.types

Examples of cc.mallet.types.LogNumber


  private double runForward(CRF crf, ArrayList<GEConstraint> constraints1, ArrayList<GEConstraint> constraints2, double[][] gammas,
      double[][][] xis, int[][] reverseTrans, FeatureVectorSequence fvs) {
    double dotEx = 0;
 
    LogNumber[] oneStateValueCache = new LogNumber[numStates];
    LogNumber nuAlpha = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
    LogNumber temp = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
   
    for (int ip = 0; ip < latticeLength-1; ++ip) {
      FeatureVector fv = fvs.get(ip);
      // speed things up by giving the constraints an
      // opportunity to cache, for example, which
      // constrained input features appear in this
      // FeatureVector
      for (GEConstraint constraint : constraints1) {
        constraint.preProcess(fv);
      }
      for (GEConstraint constraint : constraints2) {
        constraint.preProcess(fv);
      }
     
      boolean[] oneStateValComputed = new boolean[numStates];
      for (int prev = 0; prev < numStates; prev++) {
        nuAlpha.set(Transducer.IMPOSSIBLE_WEIGHT,true);
        if (ip != 0) {
          int[] prevPrevs = reverseTrans[prev];
          // calculate only once: \sum_y_{i-1} w_a(y_{i-1},y_i)
          for (int ppi = 0; ppi < prevPrevs.length; ppi++) {
            nuAlpha.plusEquals(lattice[ip-1][prevPrevs[ppi]].alpha[prev]);
          }
        }

        assert (!Double.isNaN(nuAlpha.logVal));

        CRF.State prevState = (CRF.State)crf.getState(prev);
        LatticeNode node = lattice[ip][prev];
        double[] xi = xis[ip][prev];
        double gamma = gammas[ip][prev];

        for (int ci = 0; ci < prevState.numDestinations(); ci++) {
          int curr = prevState.getDestinationState(ci).getIndex();
          double dot = 0;
          for (GEConstraint constraint : constraints2) {
            dot += constraint.getCompositeConstraintFeatureValue(fv, ip, prev, curr);
          }

          // avoid recomputing one-state constraint features #labels times
          if (!oneStateValComputed[curr]) {
            double osVal = 0;
            for (GEConstraint constraint : constraints1) {
              osVal += constraint.getCompositeConstraintFeatureValue(fv, ip, prev, curr);
            }
            if (osVal < 0) {
              dotEx += Math.exp(gammas[ip+1][curr]) * osVal;
              oneStateValueCache[curr] = new LogNumber(Math.log(-osVal),false);
            }
            else if (osVal > 0) {
              dotEx += Math.exp(gammas[ip+1][curr]) * osVal;
              oneStateValueCache[curr] = new LogNumber(Math.log(osVal),true);
            }
            else {
              oneStateValueCache[curr] = null;
            }
            oneStateValComputed[curr] = true;
          }
         
          // combine the one and two state constraint feature values
          if (dot == 0 && oneStateValueCache[curr] == null) {
            dotCache[ip][prev][curr] = null;
          }
          else if (dot == 0 && oneStateValueCache[curr] != null) {
            dotCache[ip][prev][curr] = oneStateValueCache[curr];
          }
          else {
            dotEx += Math.exp(xi[curr]) * dot;
            if (dot < 0) {
              dotCache[ip][prev][curr] = new LogNumber(Math.log(-dot),false);
            }
            else {
              dotCache[ip][prev][curr] = new LogNumber(Math.log(dot),true);
            }
            if (oneStateValueCache[curr] != null) {
              dotCache[ip][prev][curr].plusEquals(oneStateValueCache[curr]);
            }
          }
         
          // update the dynamic programming table
          if (dotCache[ip][prev][curr] != null) {
            temp.set(xi[curr],true);
            temp.timesEquals(dotCache[ip][prev][curr]);
            node.alpha[curr].plusEquals(temp);
          }
          if (gamma == Transducer.IMPOSSIBLE_WEIGHT) {
            node.alpha[curr] = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
          } else {
            temp.set(xi[curr] - gamma,true);
            temp.timesEquals(nuAlpha);
            node.alpha[curr].plusEquals(temp);
          }
          assert (!Double.isNaN(node.alpha[curr].logVal)) : "xi: " + xi[curr] + ", gamma: "
              + gamma + ", constraint feature: " + dotCache[ip][prev][curr]
              + ", nuApha: " + nuAlpha + " dot: " + dot;
View Full Code Here


   * @return
   */
  private void runBackward(CRF crf, double[][] gammas, double[][][] xis, int[][] reverseTrans, int[][] reverseTransIndices,
      FeatureVectorSequence fvs, double dotEx, CRF.Factors gradient) {
   
    LogNumber nuBeta = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
    LogNumber dot = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
    LogNumber temp = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
    LogNumber temp2 = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
    LogNumber nextDot;
   
    for (int ip = latticeLength-2; ip >= 0; --ip) {
      for (int curr = 0; curr < numStates; ++curr) {

        nuBeta.set(Transducer.IMPOSSIBLE_WEIGHT,true);
        dot.set(Transducer.IMPOSSIBLE_WEIGHT,true);
        // calculate only once: \sum_y_{i+1} w_b(y_i,y+i)
       
       
        CRF.State currState = (CRF.State)crf.getState(curr);
        for (int ni = 0; ni < currState.numDestinations(); ni++){
          int next= currState.getDestinationState(ni).getIndex();
          nuBeta.plusEquals(lattice[ip+1][curr].beta[next]);
          assert(!Double.isNaN(nuBeta.logVal));

          nextDot = dotCache[ip+1][curr][next];
          if (nextDot != null) {
            double xi = xis[ip+1][curr][next];
            temp.set(xi,true);
            temp.timesEquals(nextDot);
            dot.plusEquals(temp);
          }
        }

        double gamma = gammas[ip+1][curr];

        int[] prevStates = reverseTrans[curr];
        for (int pi = 0; pi < prevStates.length; pi++) {
          int prev = prevStates[pi];
         
          CRF.State crfState = (CRF.State)crf.getState(prev);

          LatticeNode node = lattice[ip][prev];
          double xi = xis[ip][prev][curr];

          if (gamma == Transducer.IMPOSSIBLE_WEIGHT) {
            node.beta[curr] = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
          } else {
            // constraint feature values cached in Forward pass
            temp.set(dot.logVal,dot.sign);
            temp.plusEquals(nuBeta);
            temp2.set(xi-gamma,true);
View Full Code Here

    public LatticeNode() {
      alpha = new LogNumber[numStates];
      beta = new LogNumber[numStates];
      for (int si = 0; si < numStates; ++si) {
        alpha[si] = new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
        beta[si] new LogNumber(Transducer.IMPOSSIBLE_WEIGHT,true);
      }
    }
View Full Code Here

TOP

Related Classes of cc.mallet.types.LogNumber

Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Oracle Corporation (originally of Sun Microsystems, Inc.). Contact coftware#gmail.com.