Package org.renjin.stats.internals.models

Source Code of org.renjin.stats.internals.models.FormulaInterpreterTest

package org.renjin.stats.internals.models;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.junit.Assert.assertThat;

import java.util.List;

import org.junit.Test;
import org.renjin.eval.EvalException;
import org.renjin.parser.RParser;
import org.renjin.sexp.ExpressionVector;
import org.renjin.sexp.FunctionCall;
import org.renjin.sexp.SEXP;
import org.renjin.sexp.Symbol;


import org.renjin.stats.internals.models.Formula;
import org.renjin.stats.internals.models.FormulaInterpreter;
import org.renjin.stats.internals.models.Term;

import com.google.common.collect.Lists;

public class FormulaInterpreterTest {

  @Test
  public void simple() {
   
    assertThat(build("y ~ x"), equalTo(formula("y", terms("x"))));
    assertThat(build("y ~ x + y"), equalTo(formula("y", terms("x", "y"))));
    assertThat(build("y ~ x + y + x:y"), equalTo(formula("y", terms("x", "y", interaction("x", "y")))));
  }

  @Test
  public void withArithmatic() {
    assertThat(build("y ~ log(x)"), equalTo(formula("y", terms("log(x)"))));
    assertThat(build("y ~ I(x+1)"), equalTo(formula("y", terms("I(x+1)"))));
  }

  @Test
  public void expansion() {
    assertThat(build("y ~ a * b"), equalTo(formula("y", terms("a","b", interaction("a","b")))));
    assertThat(build("y ~ a:c * b"), equalTo(formula("y", terms("b",interaction("a","c"), interaction("a","c","b")))));
  }
 
  @Test
  public void groupedExpansion() {
    assertThat(build("y ~ (a+b) * c"), equalTo(formula("y", terms("a","b", "c", interaction("a","c"), interaction("b","c")))));
    assertThat(build("y ~ (a+b) * (c+d)"), equalTo(formula("y",
        terms("a","b", "c", "d",
             interaction("a","c"),
             interaction("a","d"),
             interaction("b","c"),
             interaction("b","d")))));
  }
 
  @Test
  public void removeIntercept() {
    assertThat(build("y ~ x - 1"), equalTo(formula("y", 0, terms("x"))));
    assertThat(build("y ~ 1 - 1"), equalTo(formula("y", 0, terms())));
    assertThat(build("y ~ (-1)"), equalTo(formula("y", 0, terms())));
    assertThat(build("y ~ 1 - 1 + 1"), equalTo(formula("y", 1, terms())));

  }

  @Test
  public void substractingTerms() {
    assertThat(build("y ~ a * b - b"), equalTo(formula("y", terms("a", interaction("a","b")))));
  }
 
  @Test(expected=EvalException.class)
  public void invalidIntercept() {
    build("y ~ 6");
  }
 
  private Formula build(String source) {
    SEXP expr = parse(source);
    return new FormulaInterpreter().interpret((FunctionCall) expr);
  }

  private SEXP parse(String source) {
    ExpressionVector tree = RParser.parseSource(source + "\n");
    SEXP expr = tree.getElementAsSEXP(0);
    return expr;
  }
 
  private Formula formula(String response, List<Term> terms) {
    return new Formula(Symbol.get(response), 1, terms);
  }
 
  private Formula formula(String response, int intercept, List<Term> terms) {
    return new Formula(Symbol.get(response), intercept, terms);
  }
 
 
  private List<Term> terms(Object...terms) {
    List<Term> list = Lists.newArrayList();
    for(Object term : terms) {
      if(term instanceof String) {
        list.add(new Term(parse((String)term)));
      } else if(term instanceof Term) {
        list.add((Term)term);
      } else {
        throw new IllegalArgumentException(term.toString());
      }
    }
    return list;
  }
 
  private Term interaction(String... variableNames) {
    List<SEXP> variables = Lists.newArrayList();
    for(String name : variableNames) {
      variables.add(parse(name));
    }
    return new Term(variables);
 
}
TOP

Related Classes of org.renjin.stats.internals.models.FormulaInterpreterTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Oracle Corporation (originally of Sun Microsystems, Inc.). Contact coftware#gmail.com.