{
debug.println(2, "training on "+train.size()+" train data and "+test.size()+" test data");
//first training
debug.print(3, "first training ");
svm = new DoublePegasosSVM();
svm.setLambda(lambda);
svm.setK(k);
svm.setT(T);
svm.setT0(t0);
svm.train(train);
debug.println(3, " done.");
//affect numplus highest output to plus class
debug.println(3, "affecting 1 to the "+numplus+" highest output");
SortedSet<TrainingSample<double[]>> sorted = new TreeSet<TrainingSample<double[]>>(new Comparator<TrainingSample<double[]>>(){
/**
 * Orders samples by decreasing SVM output, so the samples the current model
 * scores highest come first when the set is iterated.
 *
 * <p>Ties deliberately never yield 0: equal outputs return -1. A TreeSet treats
 * compare()==0 as "same element", so this presumably keeps distinct samples with
 * identical scores from being dropped by addAll.
 * NOTE(review): this breaks the Comparator contract (compare(x, x) != 0 and the
 * sign is not antisymmetric on ties), so TreeSet lookups such as contains() or
 * remove() on a set built with this comparator are unreliable.
 * svm.valueOf is evaluated twice per comparison, which may be costly for large sets.
 *
 * @param o1 first sample
 * @param o2 second sample
 * @return negative if o1 has the higher output (or on a tie), positive if o2 does
 */
@Override
public int compare(TrainingSample<double[]> o1, TrainingSample<double[]> o2) {
    // Double.compare has exactly the semantics of Double.compareTo, but avoids the
    // deprecated Double(double) constructor and the extra boxing. As in the original
    // expression, o2's output is evaluated before o1's.
    int ret = Double.compare(svm.valueOf(o2.sample), svm.valueOf(o1.sample));
    // Never report equality (see the Javadoc above).
    return ret != 0 ? ret : -1;
}
});
sorted.addAll(test);
debug.println(3, "sorted size : "+sorted.size()+" test size : "+test.size());
int n = 0;
for(TrainingSample<double[]> t : sorted)
{
if(n <= numplus)
t.label = 1;
else
t.label = -1;
n++;
}
double C = 1. / (train.size()*lambda) ;
double Cminus = 1e-5;
double Cplus = 1e-5 * numplus/(test.size() - numplus);
while(Cminus < C || Cplus < C)
{
//solve full problem
ArrayList<TrainingSample<double[]>> full = new ArrayList<TrainingSample<double[]>>();
full.addAll(train);
full.addAll(test);
debug.print(3, "full training ");
svm = new DoublePegasosSVM();
svm.setLambda(lambda);
svm.setK(k);
svm.setT(T);
svm.setT0(t0);
svm.train(full);
debug.println(3, "done.");
boolean changed = false;
do
{
changed = false;
//0. computing error
final Map<TrainingSample<double[]>, Double> errorCache = new HashMap<TrainingSample<double[]>, Double>();
for(TrainingSample<double[]> t : test)
{
double err1 = 1. - t.label * svm.valueOf(t.sample);
errorCache.put(t, err1);
}
debug.println(3, "Error cache done.");
// 1 . sort by descending error
sorted = new TreeSet<TrainingSample<double[]>>(new Comparator<TrainingSample<double[]>>(){
/**
 * Orders samples by decreasing cached error, so the samples with the largest
 * error are visited first.
 *
 * <p>Equal errors yield -1 rather than 0, so the TreeSet never treats two samples
 * as duplicates. NOTE(review): this violates the Comparator contract, so lookups
 * such as contains() or remove() on the resulting set are unreliable.
 * Every compared sample must have an entry in errorCache; a missing entry
 * results in a NullPointerException.
 *
 * @param o1 first sample
 * @param o2 second sample
 * @return negative if o1 has the larger error (or on a tie), positive if o2 does
 */
@Override
public int compare(TrainingSample<double[]> o1,
        TrainingSample<double[]> o2) {
    final double errorOfSecond = errorCache.get(o2);
    final double errorOfFirst = errorCache.get(o1);
    final int order = Double.compare(errorOfSecond, errorOfFirst);
    if (order != 0) {
        return order;
    }
    // Tie: never report equality, so both samples stay in the set.
    return -1;
}
});
sorted.addAll(test);
List<TrainingSample<double[]>> sortedList = new ArrayList<TrainingSample<double[]>>();
sortedList.addAll(sorted);
debug.println(3, "sorting done, checking couple");
// 2 . test all couple by decreasing error order
// for(TrainingSample<T> i1 : sorted)
for(int i = 0 ; i < sortedList.size(); i++)
{
TrainingSample<double[]> i1 = sortedList.get(i);
// for(TrainingSample<T> i2 : sorted)
for(int j = i+1; j < sortedList.size(); j++)
{
TrainingSample<double[]> i2 = sortedList.get(j);
if(examine(i1, i2, errorCache))
{
debug.println(3, "couple found !");
changed = true;
break;
}
}
if(changed)
break;
}
if(changed)
{
debug.println(3, "re-training");
svm = new DoublePegasosSVM();
svm.setLambda(lambda);
svm.setK(k);
svm.setT(T);
svm.setT0(t0);
svm.train(full);