package edu.neu.ccs.task;
import java.io.InputStream;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.xml.namespace.QName;
import javax.xml.stream.XMLEventReader;
import javax.xml.stream.XMLInputFactory;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.events.Attribute;
import javax.xml.stream.events.StartElement;
import javax.xml.stream.events.XMLEvent;
import javax.xml.transform.stream.StreamSource;
import edu.neu.ccs.task.dialogue.Display;
import edu.neu.ccs.task.dialogue.Input;
import edu.neu.ccs.task.dialogue.MenuUserAction;
import edu.neu.ccs.task.dialogue.Output;
import edu.neu.ccs.task.dialogue.Result;
import edu.neu.ccs.task.dialogue.TextUserAction;
import edu.neu.ccs.task.dialogue.Turn;
import edu.neu.ccs.task.dialogue.WidgetUserAction;
/**
 * Reads a task model from an XML document using the StAX event API and
 * builds a {@link TaskModel} from the elements it recognizes.
 * Child elements whose names are not recognized are skipped.
 */
public class TaskModelXmlReader {
// Namespace URIs of the two element vocabularies recognized by this reader.
private static final String CETASK_NS = "http://ce.org/cea-2018";
private static final String DTASK_NS = "http://ccs.neu.edu/research/rag/task/dialogue";
// Qualified element names (E_*), each bound to one of the namespaces above.
private static final QName E_AGENT = new QName(DTASK_NS, "agent");
private static final QName E_AGENTALT = new QName(DTASK_NS, "agentalt");
private static final QName E_APPLICABLE = new QName(CETASK_NS, "applicable");
private static final QName E_BELIEF_UPDATE = new QName(DTASK_NS, "belief-update");
private static final QName E_BINDING = new QName(CETASK_NS, "binding");
private static final QName E_INPUT = new QName(CETASK_NS, "input");
private static final QName E_OUTPUT = new QName(CETASK_NS, "output");
private static final QName E_PRECONDITION = new QName(CETASK_NS, "precondition");
private static final QName E_POSTCONDITION = new QName(CETASK_NS, "postcondition");
private static final QName E_PRIORITY = new QName(DTASK_NS, "priority");
private static final QName E_REPEATREQ = new QName(DTASK_NS, "repeatreq");
private static final QName E_RESULT = new QName(DTASK_NS, "result");
private static final QName E_SAY = new QName(DTASK_NS, "say");
private static final QName E_SCRIPT = new QName(CETASK_NS, "script");
private static final QName E_STEP = new QName(CETASK_NS, "step");
private static final QName E_SUBTASKS = new QName(CETASK_NS, "subtasks");
private static final QName E_TASK = new QName(CETASK_NS, "task");
private static final QName E_TURN = new QName(DTASK_NS, "turn");
private static final QName E_USER = new QName(DTASK_NS, "user");
private static final QName E_TEXTINPUT = new QName(DTASK_NS, "textinput");
private static final QName E_WIDGETINPUT = new QName(DTASK_NS, "widgetinput");
private static final QName E_PARAMETER = new QName(DTASK_NS, "parameter");
private static final QName E_DISPLAY = new QName(DTASK_NS, "display");
// Attribute names (A_*). All are unqualified except A_PRIORITY, which is
// declared in DTASK_NS.
private static final QName A_ABOUT = new QName("about");
private static final QName A_APPLICABLE = new QName("applicable");
//private static final QName A_GLOBAL = new QName("global");
private static final QName A_GOAL = new QName("goal");
private static final QName A_ID = new QName("id");
private static final QName A_INIT = new QName("init");
private static final QName A_MAX_OCCURS = new QName("maxOccurs");
private static final QName A_MIN_OCCURS = new QName("minOccurs");
private static final QName A_NAME = new QName("name");
private static final QName A_ORDERED = new QName("ordered");
private static final QName A_PRIORITY = new QName(DTASK_NS, "priority");
private static final QName A_REQUIRED = new QName("required");
private static final QName A_SUFFICIENT = new QName("sufficient");
private static final QName A_SLOT = new QName("slot");
private static final QName A_TASK = new QName("task");
private static final QName A_TYPE = new QName("type");
private static final QName A_VALUE = new QName("value");
private static final QName A_WIDGET = new QName("widget");
private static final QName A_URL = new QName("url");
private static final QName A_AUTHENTICATION = new QName("authentication");
private static final QName A_SIZE = new QName("size");
// Model set stored into every type builder by typeBase().
private final TaskModelSet modelSet;
// XML input; its system id is also passed to the TaskModel constructor.
private final StreamSource source;
// Value of the root element's "about" attribute. Assigned in taskModel()
// and used as the namespace of type names and of unprefixed QName values.
private String uri;
/**
 * Creates a reader over an already prepared stream source.
 *
 * @param modelSet model set handed to every type built by this reader
 * @param source   XML input to parse
 */
public TaskModelXmlReader(TaskModelSet modelSet, StreamSource source) {
    this.source = source;
    this.modelSet = modelSet;
}

/**
 * Creates a reader over a raw input stream.
 *
 * @param modelSet model set handed to every type built by this reader
 * @param in       stream containing the XML document
 * @param location system id recorded for the stream
 */
public TaskModelXmlReader(TaskModelSet modelSet, InputStream in, String location) {
    this(modelSet, new StreamSource(in, location));
}
/**
 * Parses the source document and returns the task model described by its
 * root element.
 *
 * <p>The event reader created here is released before returning, whether
 * parsing succeeds or fails. Closing the reader does not close the
 * underlying input source; that remains the caller's responsibility.
 *
 * @return the task model built from the document's root element
 * @throws XMLStreamException if the document cannot be parsed or does not
 *         have the expected structure
 */
public TaskModel read() throws XMLStreamException {
    // NOTE(review): the factory is used with its default configuration.
    // If task models can come from untrusted sources, consider disabling
    // DTD support / external entities here (XXE).
    XMLInputFactory xif = XMLInputFactory.newInstance();
    XMLEventReader er = xif.createXMLEventReader(source);
    try {
        // Skip the prolog (start-document, comments, ...) up to the root.
        XMLEvent e;
        do {
            e = er.nextEvent();
        } while (!e.isStartElement());
        return taskModel(e.asStartElement(), er);
    } finally {
        try {
            er.close();
        } catch (XMLStreamException ignored) {
            // A failure while releasing the reader must neither mask a
            // parse error nor invalidate an already built model.
        }
    }
}
/**
 * Builds the task model from the root element and its children.
 *
 * <p>Records the root's {@code about} attribute as the model URI, then
 * collects one or more types from every recognized child element
 * (task, subtasks, script, turn, belief-update). Other children are
 * skipped.
 *
 * @param e  the root start element
 * @param er reader positioned just after {@code e}
 * @return the assembled task model
 * @throws XMLStreamException if the {@code about} attribute is missing or
 *         the content is malformed
 */
private TaskModel taskModel(StartElement e, XMLEventReader er)
        throws XMLStreamException {
    Attribute aboutA = e.getAttributeByName(A_ABOUT);
    if (aboutA == null) {
        // Previously surfaced as a bare NullPointerException.
        throw new XMLStreamException(
                "Missing required attribute 'about'", e.getLocation());
    }
    uri = aboutA.getValue();
    List<Type> types = new ArrayList<Type>();
    XMLEvent next;
    while (!(next = nextTag(er)).isEndElement()) {
        StartElement nextE = next.asStartElement();
        QName name = nextE.getName();
        if (name.equals(E_TASK)) {
            // A task may contribute nested types in addition to itself.
            types.addAll(task(nextE, er));
        } else if (name.equals(E_SUBTASKS)) {
            types.add(subtasks(nextE, er));
        } else if (name.equals(E_SCRIPT)) {
            types.add(script(nextE, er));
        } else if (name.equals(E_TURN)) {
            types.add(turn(nextE, er));
        } else if (name.equals(E_BELIEF_UPDATE)) {
            types.add(beliefUpdate(nextE, er));
        } else {
            skipElement(er);
        }
    }
    return new TaskModel(uri, source.getSystemId(), types);
}
/**
 * Reads a task element.
 *
 * <p>Returns every type declared by the element: the types nested inside
 * it (subtasks, script, turn, belief-update), each bound to this task's
 * name, followed by the task type itself as the last entry.
 *
 * @param e  the task start element
 * @param er reader positioned just after {@code e}
 * @return nested types in document order, then the task type
 * @throws XMLStreamException if the content is malformed
 */
private List<Type> task(StartElement e, XMLEventReader er)
        throws XMLStreamException {
    TaskType.Builder builder = new TaskType.Builder();
    typeBase(e, builder);
    List<Type> declared = new ArrayList<Type>();
    for (XMLEvent ev = nextTag(er); !ev.isEndElement(); ev = nextTag(er)) {
        StartElement child = ev.asStartElement();
        QName tag = child.getName();
        if (E_INPUT.equals(tag)) {
            builder.inputs.add(slot(child, er, true));
        } else if (E_OUTPUT.equals(tag)) {
            builder.outputs.add(slot(child, er, false));
        } else if (E_PRECONDITION.equals(tag)) {
            builder.precondition = er.getElementText();
        } else if (E_POSTCONDITION.equals(tag)) {
            builder.postconditionSufficient =
                    attribAsBool(child, A_SUFFICIENT, false);
            builder.postcondition = er.getElementText();
        } else if (E_SUBTASKS.equals(tag)) {
            declared.add(subtasks(child, er, builder.name));
        } else if (E_SCRIPT.equals(tag)) {
            declared.add(script(child, er, builder.name));
        } else if (E_TURN.equals(tag)) {
            declared.add(turn(child, er, builder.name));
        } else if (E_BELIEF_UPDATE.equals(tag)) {
            declared.add(beliefUpdate(child, er, builder.name));
        } else {
            // Unknown child: consume it entirely.
            skipElement(er);
        }
    }
    declared.add(new TaskType(builder));
    return declared;
}
/**
 * Reads an input or output slot declaration and consumes the rest of the
 * element.
 *
 * @param e     the input/output start element
 * @param er    reader positioned just after {@code e}
 * @param input {@code true} for an input slot, {@code false} for an output
 * @return the slot described by the element's {@code name} and
 *         {@code type} attributes
 * @throws XMLStreamException if {@code name} or {@code type} is missing
 */
private Slot slot(StartElement e, XMLEventReader er, boolean input)
        throws XMLStreamException {
    Attribute nameA = e.getAttributeByName(A_NAME);
    if (nameA == null) {
        throw new XMLStreamException(
                "Missing required attribute 'name'", e.getLocation());
    }
    Attribute typeA = e.getAttributeByName(A_TYPE);
    if (typeA == null) {
        throw new XMLStreamException(
                "Missing required attribute 'type'", e.getLocation());
    }
    skipElement(er);
    return new Slot(nameA.getValue(), typeA.getValue(), input, false);
}
/**
 * Reads a stand-alone subtasks element, taking the decomposed goal from
 * its {@code goal} attribute.
 *
 * @param e  the subtasks start element
 * @param er reader positioned just after {@code e}
 * @return the decomposition described by the element
 * @throws XMLStreamException if the content is malformed
 */
private DecompositionType subtasks(StartElement e, XMLEventReader er)
        throws XMLStreamException {
    QName goal = attribAsQName(e, A_GOAL);
    return subtasks(e, er, goal);
}
/**
 * Reads a subtasks element for the given goal.
 *
 * <p>Collects the element's steps, applicability condition, priority and
 * bindings; other children are skipped. When the decomposition is ordered
 * (the default), each step after the first is passed the previous step's
 * name as its implied requirement.
 *
 * @param e    the subtasks start element
 * @param er   reader positioned just after {@code e}
 * @param goal the task type being decomposed
 * @return the decomposition described by the element
 * @throws XMLStreamException if the content is malformed, including a
 *         priority that is not an integer
 */
private DecompositionType subtasks(
        StartElement e, XMLEventReader er,
        QName goal) throws XMLStreamException {
    DecompositionType.Builder b = new DecompositionType.Builder();
    typeBase(e, b);
    b.goal = goal;
    b.ordered = attribAsBool(e, A_ORDERED, true);
    XMLEvent next;
    while (!(next = nextTag(er)).isEndElement()) {
        StartElement nextE = next.asStartElement();
        QName name = nextE.getName();
        if (name.equals(E_STEP)) {
            b.steps.add(step(nextE, er, getRequiredStep(b)));
        } else if (name.equals(E_APPLICABLE)) {
            b.applicable = er.getElementText();
        } else if (name.equals(E_PRIORITY)) {
            // Trim so that pretty-printed content such as "\n  5\n" parses,
            // and report bad numbers with their document location instead
            // of a bare NumberFormatException.
            String text = er.getElementText().trim();
            try {
                b.priority = Integer.parseInt(text);
            } catch (NumberFormatException ex) {
                throw new XMLStreamException(
                        "Invalid priority '" + text + "'",
                        nextE.getLocation(), ex);
            }
        } else if (name.equals(E_BINDING)) {
            b.bindings.add(binding(nextE, er));
        } else {
            skipElement(er);
        }
    }
    return new DecompositionType(b);
}
/**
 * Returns the name of the step that the next step implicitly depends on.
 *
 * @param b the decomposition being built
 * @return the name of the most recently added step when the decomposition
 *         is ordered and already has a step; {@code null} otherwise
 */
private String getRequiredStep(DecompositionType.Builder b) {
    if (b.ordered && !b.steps.isEmpty()) {
        int last = b.steps.size() - 1;
        return b.steps.get(last).getName();
    }
    return null;
}
/**
 * Reads a step element and consumes the rest of it.
 *
 * <p>The set of required steps is {@code impliedRequired} alone when that
 * is non-null; otherwise it is the whitespace-separated list in the
 * {@code required} attribute (empty when the attribute is absent or
 * blank). {@code minOccurs} and {@code maxOccurs} default to 1, and a
 * {@code maxOccurs} of {@code "unbounded"} maps to
 * {@link Integer#MAX_VALUE}.
 *
 * @param e               the step start element
 * @param er              reader positioned just after {@code e}
 * @param impliedRequired name of the step this one implicitly requires,
 *                        or {@code null} for none
 * @return the step description
 * @throws XMLStreamException if {@code name} or {@code task} is missing
 */
private StepDescription step(
        StartElement e, XMLEventReader er,
        String impliedRequired) throws XMLStreamException {
    Attribute nameA = e.getAttributeByName(A_NAME);
    if (nameA == null) {
        throw new XMLStreamException(
                "Missing required attribute 'name'", e.getLocation());
    }
    String name = nameA.getValue();
    Attribute taskA = e.getAttributeByName(A_TASK);
    if (taskA == null) {
        throw new XMLStreamException(
                "Missing required attribute 'task'", e.getLocation());
    }
    QName type = asQName(e, taskA.getValue());
    List<String> required;
    if (impliedRequired != null) {
        required = Collections.singletonList(impliedRequired);
    } else {
        Attribute requiredA = e.getAttributeByName(A_REQUIRED);
        // Trim first: splitting " a b" or "" directly would produce an
        // empty string as a bogus step name.
        String requiredText =
                (requiredA == null) ? "" : requiredA.getValue().trim();
        if (requiredText.length() == 0) {
            required = Collections.emptyList();
        } else {
            required = Arrays.asList(requiredText.split("\\s+"));
        }
    }
    int minOccurs = attribAsInt(e, A_MIN_OCCURS, 1);
    int maxOccurs;
    Attribute maxOccursA = e.getAttributeByName(A_MAX_OCCURS);
    if (maxOccursA == null) {
        maxOccurs = 1;
    } else if (maxOccursA.getValue().equals("unbounded")) {
        maxOccurs = Integer.MAX_VALUE;
    } else {
        maxOccurs = Integer.parseInt(maxOccursA.getValue());
    }
    skipElement(er);
    return new StepDescription(name, type, required, minOccurs, maxOccurs);
}
// Slot reference of the form "$step.slot"; group 1 is the step name and
// group 2 the slot name.
private static final Pattern REF_P = Pattern.compile(
        "\\$([a-zA-Z][a-zA-Z0-9_]*)\\.([a-zA-Z][a-zA-Z0-9_]*)");

/**
 * Reads a binding element and consumes the rest of it.
 *
 * @param e  the binding start element
 * @param er reader positioned just after {@code e}
 * @return the binding of the referenced step slot to the {@code value}
 *         attribute
 * @throws XMLStreamException if {@code slot} or {@code value} is missing,
 *         or {@code slot} is not of the form {@code $step.slot}
 */
private Binding binding(StartElement e, XMLEventReader er)
        throws XMLStreamException {
    Attribute slotA = e.getAttributeByName(A_SLOT);
    if (slotA == null) {
        throw new XMLStreamException(
                "Missing required attribute 'slot'", e.getLocation());
    }
    Matcher m = REF_P.matcher(slotA.getValue());
    if (!m.matches()) {
        throw new XMLStreamException("bad slot", e.getLocation());
    }
    Attribute valueA = e.getAttributeByName(A_VALUE);
    if (valueA == null) {
        throw new XMLStreamException(
                "Missing required attribute 'value'", e.getLocation());
    }
    String step = m.group(1);
    String slot = m.group(2);
    String value = valueA.getValue();
    skipElement(er);
    return new Binding(step, slot, value);
}
/**
 * Reads a stand-alone script element. The associated task comes from the
 * optional {@code task} attribute and is {@code null} when it is absent.
 *
 * @param e  the script start element
 * @param er reader positioned just after {@code e}
 * @return the script described by the element
 * @throws XMLStreamException if the content is malformed
 */
private Script script(StartElement e, XMLEventReader er)
        throws XMLStreamException {
    QName task = null;
    Attribute taskAttr = e.getAttributeByName(A_TASK);
    if (taskAttr != null) {
        task = asQName(e, taskAttr.getValue());
    }
    return script(e, er, task);
}
/**
 * Reads a script element for the given task.
 *
 * <p>{@code init} defaults to {@code false}, {@code applicable} to
 * {@code null} and the priority attribute to 0. The element's text
 * content becomes the script text.
 *
 * @param e    the script start element
 * @param er   reader positioned just after {@code e}
 * @param task task the script belongs to; may be {@code null}
 * @return the script described by the element
 * @throws XMLStreamException if the element content is not plain text
 */
private Script script(StartElement e, XMLEventReader er, QName task)
        throws XMLStreamException {
    Script.Builder builder = new Script.Builder();
    typeBase(e, builder);
    builder.task = task;
    // Attribute-derived settings first ...
    builder.priority = attribAsInt(e, A_PRIORITY, 0);
    builder.applicable = attribAsString(e, A_APPLICABLE, null);
    builder.init = attribAsBool(e, A_INIT, false);
    // ... then the body, which also consumes the element's end tag.
    builder.text = er.getElementText();
    return new Script(builder);
}
/**
 * Reads a stand-alone belief-update element. The associated task comes
 * from the optional {@code task} attribute and is {@code null} when it is
 * absent.
 *
 * @param e  the belief-update start element
 * @param er reader positioned just after {@code e}
 * @return the belief update described by the element
 * @throws XMLStreamException if the content is malformed
 */
private BeliefUpdate beliefUpdate(StartElement e, XMLEventReader er)
        throws XMLStreamException {
    QName task = null;
    Attribute taskAttr = e.getAttributeByName(A_TASK);
    if (taskAttr != null) {
        task = asQName(e, taskAttr.getValue());
    }
    return beliefUpdate(e, er, task);
}
/**
 * Reads a belief-update element for the given task; the element's text
 * content becomes the update text.
 *
 * @param e    the belief-update start element
 * @param er   reader positioned just after {@code e}
 * @param task task the update belongs to; may be {@code null}
 * @return the belief update described by the element
 * @throws XMLStreamException if the element content is not plain text
 */
private BeliefUpdate beliefUpdate(StartElement e, XMLEventReader er, QName task)
        throws XMLStreamException {
    BeliefUpdate.Builder builder = new BeliefUpdate.Builder();
    typeBase(e, builder);
    String body = er.getElementText();
    builder.task = task;
    builder.text = body;
    return new BeliefUpdate(builder);
}
/**
 * Reads a stand-alone turn element, taking its task from the
 * {@code task} attribute.
 *
 * @param e  the turn start element
 * @param er reader positioned just after {@code e}
 * @return the turn described by the element
 * @throws XMLStreamException if the content is malformed
 */
private Turn turn(StartElement e, XMLEventReader er)
        throws XMLStreamException {
    QName task = attribAsQName(e, A_TASK);
    return turn(e, er, task);
}
/**
 * Reads a turn element for the given task.
 *
 * <p>Recognized children: applicable, priority, agent, agentalt,
 * user/repeatreq (grouped into one menu user action), textinput,
 * widgetinput, result, script and display. Other children are skipped.
 * When several user-action children occur, the last one read wins.
 *
 * @param e    the turn start element
 * @param er   reader positioned just after {@code e}
 * @param task task the turn belongs to
 * @return the turn described by the element
 * @throws XMLStreamException if the content is malformed, including a
 *         non-integer priority or a display element without {@code url}
 */
private Turn turn(StartElement e, XMLEventReader er, QName task)
        throws XMLStreamException {
    Turn.Builder b = new Turn.Builder();
    typeBase(e, b);
    b.task = task;
    XMLEvent next;
    while (!(next = nextTag(er)).isEndElement()) {
        StartElement nextE = next.asStartElement();
        QName name = nextE.getName();
        if (name.equals(E_APPLICABLE)) {
            b.applicable = er.getElementText();
        } else if (name.equals(E_PRIORITY)) {
            // Trim so that pretty-printed content parses, and report bad
            // numbers with their document location instead of a bare
            // NumberFormatException.
            String text = er.getElementText().trim();
            try {
                b.priority = Integer.parseInt(text);
            } catch (NumberFormatException ex) {
                throw new XMLStreamException(
                        "Invalid priority '" + text + "'",
                        nextE.getLocation(), ex);
            }
        } else if (name.equals(E_AGENT)) {
            b.outputs.add(agent(nextE, er));
        } else if (name.equals(E_AGENTALT)) {
            b.altOutputs.add(agent(nextE, er));
        } else if (name.equals(E_USER) || name.equals(E_REPEATREQ)) {
            // Consumes this element and any directly following
            // user/repeatreq siblings.
            b.userAction = menuUserAction(nextE, er);
        } else if (name.equals(E_TEXTINPUT)) {
            b.userAction = textUserAction(nextE, er);
        } else if (name.equals(E_WIDGETINPUT)) {
            b.userAction = widgetUserAction(nextE, er);
        } else if (name.equals(E_RESULT)) {
            b.bindings.add(result(nextE, er));
        } else if (name.equals(E_SCRIPT)) {
            b.scripts.add(script(nextE, er, null));
        } else if (name.equals(E_DISPLAY)) {
            Attribute urlA = nextE.getAttributeByName(A_URL);
            if (urlA == null) {
                throw new XMLStreamException(
                        "Missing required attribute 'url'",
                        nextE.getLocation());
            }
            String type = attribAsString(nextE, A_TYPE, null);
            String size = attribAsString(nextE, A_SIZE, null);
            boolean auth = attribAsBool(nextE, A_AUTHENTICATION, false);
            b.display = new Display(urlA.getValue(), type, size, auth);
            skipElement(er);
        } else {
            skipElement(er);
        }
    }
    return new Turn(b);
}
/**
 * Reads the mixed content of an agent (or agentalt) element up to its
 * matching end tag.
 *
 * <p>Two renderings are produced: the concatenated character data, and
 * the serialized form of every event inside the element (nested markup
 * included). In both, surrounding whitespace is trimmed and each run of
 * whitespace is collapsed to a single space.
 *
 * @param e  the start element (not inspected)
 * @param er reader positioned just after {@code e}
 * @return the output holding the plain and the annotated text
 * @throws XMLStreamException if reading or serializing an event fails
 */
private Output agent(StartElement e, XMLEventReader er)
        throws XMLStreamException {
    StringWriter markup = new StringWriter();
    StringBuilder text = new StringBuilder();
    // Number of elements currently open, counting the agent element.
    int open = 1;
    for (;;) {
        XMLEvent ev = er.nextEvent();
        if (ev.isEndElement() && --open <= 0) {
            // End tag of the agent element itself: stop without writing it.
            break;
        }
        if (ev.isStartElement()) {
            open++;
        }
        if (ev.isCharacters()) {
            text.append(ev.asCharacters().getData());
        }
        ev.writeAsEncodedUnicode(markup);
    }
    String plainText = text.toString().trim().replaceAll("\\s+", " ");
    String annotatedText = markup.toString().trim().replaceAll("\\s+", " ");
    return new Output(plainText, annotatedText);
}
/**
 * Reads a run of consecutive user and repeatreq sibling elements,
 * starting with {@code e}, into a single menu user action.
 *
 * <p>The run ends at the enclosing end tag or at the first sibling that
 * is neither user nor repeatreq; that tag is only peeked at and stays in
 * the reader for the caller.
 *
 * @param e  the first user or repeatreq start element (already consumed)
 * @param er reader positioned just after {@code e}
 * @return the menu action holding the collected inputs and repeat texts
 * @throws XMLStreamException if the content is malformed
 */
private MenuUserAction menuUserAction(StartElement e, XMLEventReader er)
        throws XMLStreamException {
    List<Input> inputs = new ArrayList<Input>();
    List<String> repeatTexts = new ArrayList<String>();
    StartElement current = e;
    QName tag = current.getName();
    while (true) {
        if (E_USER.equals(tag)) {
            inputs.add(user(current, er));
        } else if (E_REPEATREQ.equals(tag)) {
            String repeat = er.getElementText().trim();
            repeatTexts.add(repeat.replaceAll("\\s+", " "));
        }
        XMLEvent upcoming = peekNextTag(er);
        if (upcoming.isEndElement()) {
            break;
        }
        current = upcoming.asStartElement();
        tag = current.getName();
        boolean partOfRun = E_USER.equals(tag) || E_REPEATREQ.equals(tag);
        if (!partOfRun) {
            break;
        }
        // The peeked tag belongs to the run: consume it for real.
        er.nextEvent();
    }
    return new MenuUserAction(inputs, repeatTexts);
}
/**
 * Reads a user element: its say texts (whitespace-normalized), result
 * bindings and scripts. Other children are skipped.
 *
 * @param e  the user start element (not inspected)
 * @param er reader positioned just after {@code e}
 * @return the input holding the collected texts, bindings and scripts
 * @throws XMLStreamException if the content is malformed
 */
private Input user(StartElement e, XMLEventReader er)
        throws XMLStreamException {
    List<String> utterances = new ArrayList<String>();
    List<Result> results = new ArrayList<Result>();
    List<Script> attached = new ArrayList<Script>();
    for (XMLEvent ev = nextTag(er); !ev.isEndElement(); ev = nextTag(er)) {
        StartElement child = ev.asStartElement();
        QName tag = child.getName();
        if (E_SAY.equals(tag)) {
            String said = er.getElementText().trim();
            utterances.add(said.replaceAll("\\s+", " "));
        } else if (E_RESULT.equals(tag)) {
            results.add(result(child, er));
        } else if (E_SCRIPT.equals(tag)) {
            attached.add(script(child, er, null));
        } else {
            skipElement(er);
        }
    }
    return new Input(utterances, results, attached);
}
/**
 * Reads a result element and consumes the rest of it.
 *
 * @param e  the result start element
 * @param er reader positioned just after {@code e}
 * @return the result built from the {@code slot} and {@code value}
 *         attributes
 * @throws XMLStreamException if {@code slot} or {@code value} is missing
 */
private Result result(StartElement e, XMLEventReader er)
        throws XMLStreamException {
    Attribute slotA = e.getAttributeByName(A_SLOT);
    if (slotA == null) {
        throw new XMLStreamException(
                "Missing required attribute 'slot'", e.getLocation());
    }
    Attribute valueA = e.getAttributeByName(A_VALUE);
    if (valueA == null) {
        throw new XMLStreamException(
                "Missing required attribute 'value'", e.getLocation());
    }
    skipElement(er);
    return new Result(slotA.getValue(), valueA.getValue());
}
/**
 * Reads a textinput element. The whitespace-normalized text content is
 * the prompt; the {@code slot} attribute names the target slot.
 *
 * @param e  the textinput start element
 * @param er reader positioned just after {@code e}
 * @return the text user action
 * @throws XMLStreamException if {@code slot} is missing or the element
 *         content is not plain text
 */
private TextUserAction textUserAction(StartElement e, XMLEventReader er)
        throws XMLStreamException {
    Attribute slotA = e.getAttributeByName(A_SLOT);
    if (slotA == null) {
        throw new XMLStreamException(
                "Missing required attribute 'slot'", e.getLocation());
    }
    String slot = slotA.getValue();
    String prompt = er.getElementText().trim().replaceAll("\\s+", " ");
    return new TextUserAction(prompt, slot);
}
/**
 * Reads a widgetinput element together with its parameter children.
 *
 * <p>Each parameter contributes one entry: its {@code name} attribute
 * mapped to its text content, unmodified. A later parameter with the same
 * name replaces the earlier one. Other children are skipped.
 *
 * @param e  the widgetinput start element
 * @param er reader positioned just after {@code e}
 * @return the widget user action
 * @throws XMLStreamException if {@code slot} or {@code widget} is missing,
 *         a parameter has no {@code name}, or the content is malformed
 */
private WidgetUserAction widgetUserAction(StartElement e, XMLEventReader er)
        throws XMLStreamException {
    Attribute slotA = e.getAttributeByName(A_SLOT);
    if (slotA == null) {
        throw new XMLStreamException(
                "Missing required attribute 'slot'", e.getLocation());
    }
    Attribute widgetA = e.getAttributeByName(A_WIDGET);
    if (widgetA == null) {
        throw new XMLStreamException(
                "Missing required attribute 'widget'", e.getLocation());
    }
    String slot = slotA.getValue();
    String url = widgetA.getValue();
    Map<String, String> params = new HashMap<String, String>();
    XMLEvent next;
    while (!(next = nextTag(er)).isEndElement()) {
        StartElement nextE = next.asStartElement();
        if (nextE.getName().equals(E_PARAMETER)) {
            Attribute nameA = nextE.getAttributeByName(A_NAME);
            if (nameA == null) {
                throw new XMLStreamException(
                        "Missing required attribute 'name'",
                        nextE.getLocation());
            }
            params.put(nameA.getValue(), er.getElementText());
        } else {
            skipElement(er);
        }
    }
    return new WidgetUserAction(url, slot, params);
}
/**
 * Fills in the builder properties shared by all types: the model set,
 * the element's line/column position in the document, and a name formed
 * from the model URI and the element's {@code id} attribute (an empty
 * local name when the attribute is absent).
 *
 * @param e the element that declares the type
 * @param b the builder to populate
 */
private void typeBase(StartElement e, Type.Builder b) {
    int line = e.getLocation().getLineNumber();
    int column = e.getLocation().getColumnNumber();
    String id = attribAsString(e, A_ID, "");
    b.modelSet = modelSet;
    b.location = new Location(line, column);
    b.name = new QName(uri, id);
}
/**
 * Consumes events up to and including the end tag of the element whose
 * start tag was read last, nested elements included.
 *
 * @param er reader positioned just after the start tag to skip
 * @throws XMLStreamException if reading an event fails
 */
private void skipElement(XMLEventReader er) throws XMLStreamException {
    // One element is open on entry; done once its end tag has been read.
    int open = 1;
    while (open > 0) {
        XMLEvent ev = er.nextEvent();
        if (ev.isEndElement()) {
            open--;
        } else if (ev.isStartElement()) {
            open++;
        }
    }
}
/**
 * Converts a possibly prefixed name from the document into a QName.
 *
 * <p>A name without a colon is placed in the model's own namespace
 * ({@code uri}). Otherwise the text before the first colon is the prefix,
 * resolved against the namespace declarations in scope at {@code e}, and
 * the remainder is the local part.
 * NOTE(review): an unbound prefix is not reported here - confirm whether
 * that is intended.
 *
 * @param e    element providing the namespace context
 * @param name the lexical name
 * @return the resolved qualified name
 */
private QName asQName(StartElement e, String name) {
    int colon = name.indexOf(':');
    if (colon < 0) {
        // no prefix
        return new QName(uri, name);
    }
    String prefix = name.substring(0, colon);
    String localPart = name.substring(colon + 1);
    return new QName(e.getNamespaceURI(prefix), localPart, prefix);
}
/**
 * Returns the value of a required attribute, resolved as a QName.
 *
 * @param e    element carrying the attribute
 * @param name name of the attribute
 * @return the attribute value as resolved by {@code asQName}
 * @throws XMLStreamException if the attribute is absent (previously a
 *         bare NullPointerException)
 */
private QName attribAsQName(StartElement e, QName name)
        throws XMLStreamException {
    Attribute a = e.getAttributeByName(name);
    if (a == null) {
        throw new XMLStreamException(
                "Missing required attribute '" + name.getLocalPart() + "'",
                e.getLocation());
    }
    return asQName(e, a.getValue());
}
/**
 * Returns an optional attribute parsed with
 * {@link Boolean#parseBoolean(String)}.
 *
 * @param e       element carrying the attribute
 * @param name    name of the attribute
 * @param implied value to use when the attribute is absent
 * @return the parsed value, or {@code implied}
 */
private boolean attribAsBool(StartElement e, QName name, boolean implied) {
    Attribute attr = e.getAttributeByName(name);
    if (attr == null) {
        return implied;
    }
    return Boolean.parseBoolean(attr.getValue());
}
/**
 * Returns an optional attribute parsed with
 * {@link Integer#parseInt(String)}.
 *
 * @param e       element carrying the attribute
 * @param name    name of the attribute
 * @param implied value to use when the attribute is absent
 * @return the parsed value, or {@code implied}
 * @throws NumberFormatException if the attribute is present but is not a
 *         valid integer
 */
private int attribAsInt(StartElement e, QName name, int implied) {
    Attribute attr = e.getAttributeByName(name);
    if (attr == null) {
        return implied;
    }
    return Integer.parseInt(attr.getValue());
}
/**
 * Returns the value of an optional attribute.
 *
 * @param e       element carrying the attribute
 * @param name    name of the attribute
 * @param implied value to use when the attribute is absent; may be
 *                {@code null}
 * @return the attribute value, or {@code implied}
 */
private String attribAsString(StartElement e, QName name, String implied) {
    Attribute attr = e.getAttributeByName(name);
    if (attr == null) {
        return implied;
    }
    return attr.getValue();
}
/**
 * Advances to the next start or end element and returns it, skipping
 * every other event on the way except non-whitespace text, which is an
 * error.
 *
 * <p>Used in place of {@code XMLEventReader.nextTag()}, which is buggy
 * in the Sun reference implementation; this version seems to work
 * correctly.
 *
 * @param er the reader to advance
 * @return the next start or end element event
 * @throws XMLStreamException if non-whitespace text precedes the next tag
 */
private XMLEvent nextTag(XMLEventReader er) throws XMLStreamException {
    for (;;) {
        XMLEvent ev = er.nextEvent();
        if (ev.isStartElement() || ev.isEndElement()) {
            return ev;
        }
        if (ev.isCharacters() && !ev.asCharacters().isWhiteSpace()) {
            throw new XMLStreamException(
                    "Unexpected text content",
                    ev.getLocation());
        }
    }
}
/**
 * Looks ahead to the next start or end element without consuming it.
 *
 * <p>Events before that tag are consumed, as in {@code nextTag()}, and
 * non-whitespace text among them is an error. The returned tag itself
 * stays in the reader, so the caller must call {@code er.nextEvent()} to
 * move past it.
 *
 * @param er the reader to inspect
 * @return the upcoming start or end element event (never {@code null})
 * @throws XMLStreamException if non-whitespace text precedes the next
 *         tag, or the document ends before any tag is found
 */
private XMLEvent peekNextTag(XMLEventReader er) throws XMLStreamException {
    while (true) {
        XMLEvent e = er.peek();
        if (e == null) {
            // peek() reports end of stream with null; fail with a parse
            // error instead of a NullPointerException.
            throw new XMLStreamException("Unexpected end of document");
        }
        if (e.isCharacters() && !e.asCharacters().isWhiteSpace()) {
            throw new XMLStreamException(
                    "Unexpected text content",
                    e.getLocation());
        }
        if (e.isStartElement() || e.isEndElement()) {
            return e;
        }
        er.nextEvent(); // not a tag: actually advance past it
    }
}
}