Package com.facebook.presto.ml

Source Code of com.facebook.presto.ml.FeatureUnitNormalizer

/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.ml;

import com.facebook.presto.ml.type.ModelType;
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.SizeOf;
import io.airlift.slice.SliceInput;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;
import it.unimi.dsi.fastutil.ints.Int2DoubleMap;
import it.unimi.dsi.fastutil.ints.Int2DoubleOpenHashMap;

import java.util.HashMap;
import java.util.Map;

/**
* Normalizes features by making every feature value lie in [0, 1].
*/
public class FeatureUnitNormalizer
        extends AbstractFeatureTransformation
{
    private final Int2DoubleMap mins;
    private final Int2DoubleMap maxs;

    public FeatureUnitNormalizer()
    {
        mins = new Int2DoubleOpenHashMap();
        maxs = new Int2DoubleOpenHashMap();

        mins.defaultReturnValue(Double.POSITIVE_INFINITY);
        maxs.defaultReturnValue(Double.NEGATIVE_INFINITY);
    }

    @Override
    public ModelType getType()
    {
        return ModelType.MODEL;
    }

    @Override
    public byte[] getSerializedData()
    {
        // Serialization format is (<key:int><min:double><max:double>)*
        SliceOutput output = Slices.allocate((SizeOf.SIZE_OF_INT + 2 * SizeOf.SIZE_OF_DOUBLE) * mins.size()).getOutput();
        for (int key : mins.keySet()) {
            output.appendInt(key);
            output.appendDouble(mins.get(key));
            output.appendDouble(maxs.get(key));
        }
        return output.slice().getBytes();
    }

    public static FeatureUnitNormalizer deserialize(byte[] modelData)
    {
        SliceInput input = Slices.wrappedBuffer(modelData).getInput();
        FeatureUnitNormalizer model = new FeatureUnitNormalizer();
        while (input.isReadable()) {
            int key = input.readInt();
            model.mins.put(key, input.readDouble());
            model.maxs.put(key, input.readDouble());
        }
        return model;
    }

    @Override
    public void train(Dataset dataset)
    {
        for (FeatureVector vector : dataset.getDatapoints()) {
            for (Map.Entry<Integer, Double> feature : vector.getFeatures().entrySet()) {
                int key = feature.getKey();
                double value = feature.getValue();
                if (value < mins.get(key)) {
                    mins.put(key, value);
                }
                if (value > maxs.get(key)) {
                    maxs.put(key, value);
                }
            }
        }

        for (int key : ImmutableSet.copyOf(mins.keySet())) {
            // Remove any features that had a constant value
            if (mins.get(key) == maxs.get(key)) {
                mins.remove(key);
                maxs.remove(key);
            }
        }
    }

    @Override
    public FeatureVector transform(FeatureVector features)
    {
        Map<Integer, Double> transformed = new HashMap<>();
        for (Map.Entry<Integer, Double> entry : features.getFeatures().entrySet()) {
            int key = entry.getKey();
            double value = entry.getValue();
            if (mins.containsKey(entry.getKey())) {
                double min = mins.get(key);
                value = (value - min) / (maxs.get(key) - min);
            }
            else {
                // Set anything that had a constant value, or was missing, in the training set to zero
                value = 0;
            }
            // In case value is outside of the values seen in the training data, make sure it's [0, 1]
            value = Math.min(1, Math.max(0, value));
            transformed.put(entry.getKey(), value);
        }
        return new FeatureVector(transformed);
    }
}
TOP

Related Classes of com.facebook.presto.ml.FeatureUnitNormalizer

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by ORACLE Inc. Contact coftware#gmail.com.