Package hivemall.tools.mapred

Source Code of hivemall.tools.mapred.DistributedCacheLookupUDF

/*
* Hivemall: Hive scalable Machine Learning Library
*
* Copyright (C) 2013-2014
*   National Institute of Advanced Industrial Science and Technology (AIST)
*   Registration Number: H25PRO-1520
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
*/
package hivemall.tools.mapred;

import hivemall.ftvec.ExtractFeatureUDF;
import hivemall.utils.collections.OpenHashMap;
import hivemall.utils.hadoop.HadoopUtils;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.IOUtils;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.SerDeException;
import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.apache.hadoop.io.Text;

@Description(name = "distcache_gets", value = "_FUNC_(filepath, key, default_value [, parseKey]) - Returns map<key_type, value_type>|value_type")
@UDFType(deterministic = false, stateful = false)
public final class DistributedCacheLookupUDF extends GenericUDF {

    private boolean multipleKeyLookup;
    private boolean multipleDefaultValues;
    private boolean parseKey;
    private Object defaultValue;

    private PrimitiveObjectInspector keyInputOI;
    private PrimitiveObjectInspector valueInputOI;
    private ListObjectInspector keysInputOI;
    private ListObjectInspector valuesInputOI;

    private OpenHashMap<Object, Object> cache;

    @Override
    public ObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
        if(argOIs.length != 3 && argOIs.length != 4) {
            throw new UDFArgumentException("Invalid number of arguments for distcache_gets(FILEPATH, KEYS, DEFAULT_VAL, PARSE_KEY): "
                    + argOIs.length + getUsage());
        }
        if(!ObjectInspectorUtils.isConstantObjectInspector(argOIs[2])) {
            throw new UDFArgumentException("Third argument DEFAULT_VALUE must be a constant value: "
                    + TypeInfoUtils.getTypeInfoFromObjectInspector(argOIs[2]));
        }
        if(argOIs.length == 4) {
            this.parseKey = HiveUtils.getConstBoolean(argOIs[3]);
        } else {
            this.parseKey = false;
        }

        String filepath = HiveUtils.getConstString(argOIs[0]);

        ObjectInspector argOI2 = argOIs[2];
        this.multipleDefaultValues = (argOI2.getCategory() == Category.LIST);
        if(multipleDefaultValues) {
            this.valuesInputOI = (ListObjectInspector) argOI2;
            ObjectInspector valuesElemOI = valuesInputOI.getListElementObjectInspector();
            valueInputOI = HiveUtils.asPrimitiveObjectInspector(valuesElemOI);
        } else {
            this.defaultValue = HiveUtils.getConstValue(argOI2);
            valueInputOI = HiveUtils.asPrimitiveObjectInspector(argOI2);
        }
        ObjectInspector valueOutputOI = ObjectInspectorUtils.getStandardObjectInspector(valueInputOI, ObjectInspectorCopyOption.WRITABLE);

        final ObjectInspector outputOI;
        switch(argOIs[1].getCategory()) {
            case PRIMITIVE:
                this.multipleKeyLookup = false;
                this.keyInputOI = (PrimitiveObjectInspector) argOIs[1];
                outputOI = valueOutputOI;
                break;
            case LIST:
                this.multipleKeyLookup = true;
                this.keysInputOI = (ListObjectInspector) argOIs[1];
                ObjectInspector keysElemOI = keysInputOI.getListElementObjectInspector();
                this.keyInputOI = HiveUtils.asPrimitiveObjectInspector(keysElemOI);
                outputOI = ObjectInspectorFactory.getStandardMapObjectInspector(keyInputOI, valueOutputOI);
                break;
            default:
                throw new UDFArgumentException("Unexpected key type: " + argOIs[1].getTypeName());
        }

        if(parseKey && !HiveUtils.isStringOI(keyInputOI)) {
            throw new UDFArgumentException("parseKey=true is only available for string typed key(s)");
        }

        final OpenHashMap<Object, Object> map = new OpenHashMap<Object, Object>(8192);
        try {
            loadValues(map, new File(filepath), keyInputOI, valueInputOI);
            this.cache = map;
        } catch (IOException e) {
            throw new RuntimeException(e);
        } catch (SerDeException e) {
            throw new RuntimeException(e);
        }

        return outputOI;
    }

    private static void loadValues(OpenHashMap<Object, Object> map, File file, PrimitiveObjectInspector keyOI, PrimitiveObjectInspector valueOI)
            throws IOException, SerDeException {
        if(!file.exists()) {
            return;
        }
        if(!file.getName().endsWith(".crc")) {
            if(file.isDirectory()) {
                for(File f : file.listFiles()) {
                    loadValues(map, f, keyOI, valueOI);
                }
            } else {
                LazySimpleSerDe serde = HiveUtils.getKeyValueLineSerde(keyOI, valueOI);
                StructObjectInspector lineOI = (StructObjectInspector) serde.getObjectInspector();
                StructField keyRef = lineOI.getStructFieldRef("key");
                StructField valueRef = lineOI.getStructFieldRef("value");
                PrimitiveObjectInspector keyRefOI = (PrimitiveObjectInspector) keyRef.getFieldObjectInspector();
                PrimitiveObjectInspector valueRefOI = (PrimitiveObjectInspector) valueRef.getFieldObjectInspector();

                BufferedReader reader = null;
                try {
                    reader = HadoopUtils.getBufferedReader(file);
                    String line;
                    while((line = reader.readLine()) != null) {
                        Text lineText = new Text(line);
                        Object lineObj = serde.deserialize(lineText);
                        List<Object> fields = lineOI.getStructFieldsDataAsList(lineObj);
                        Object f0 = fields.get(0);
                        Object f1 = fields.get(1);
                        Object k = keyRefOI.getPrimitiveJavaObject(f0);
                        Object v = valueRefOI.getPrimitiveWritableObject(valueRefOI.copyObject(f1));
                        map.put(k, v);
                    }
                } finally {
                    IOUtils.closeQuietly(reader);
                }
            }
        }
    }

    @Override
    public Object evaluate(DeferredObject[] args) throws HiveException {
        final Object arg1 = args[1].get();
        if(multipleKeyLookup) {
            if(multipleDefaultValues) {
                Object arg2 = args[2].get();
                return gets(arg1, arg2);
            } else {
                return gets(arg1);
            }
        } else {
            return get(arg1);
        }
    }

    private Object get(Object arg) {
        Object key = keyInputOI.getPrimitiveJavaObject(arg);
        Object value = cache.get(lookupKey(key));
        return (value == null) ? defaultValue : value;
    }

    private Map<Object, Object> gets(Object arg) {
        List<?> keys = keysInputOI.getList(arg);

        final Map<Object, Object> map = new HashMap<Object, Object>();
        for(Object k : keys) {
            if(k == null) {
                continue;
            }
            Object kj = keyInputOI.getPrimitiveJavaObject(k);
            final Object v = cache.get(lookupKey(kj));
            if(v == null) {
                map.put(k, defaultValue);
            } else {
                map.put(k, v);
            }
        }
        return map;
    }

    private Map<Object, Object> gets(Object argKeys, Object argValues) throws HiveException {
        final List<?> keys = keysInputOI.getList(argKeys);
        final List<?> defaultValues = valuesInputOI.getList(argValues);
        final int numKeys = keys.size();
        if(numKeys != defaultValues.size()) {
            throw new HiveException("# of default values != # of lookup keys: keys " + argKeys
                    + ", values: " + argValues);
        }

        final Map<Object, Object> map = new HashMap<Object, Object>();
        for(int i = 0; i < numKeys; i++) {
            Object k = keys.get(i);
            if(k == null) {
                continue;
            }
            Object kj = keyInputOI.getPrimitiveJavaObject(k);
            Object v = cache.get(lookupKey(kj));
            if(v == null) {
                v = defaultValues.get(i);
                if(v != null) {
                    v = valueInputOI.getPrimitiveWritableObject(valueInputOI.copyObject(v));
                }
            }
            map.put(k, v);
        }
        return map;
    }

    private Object lookupKey(final Object key) {
        if(parseKey) {
            String keyStr = key.toString();
            return ExtractFeatureUDF.extractFeature(keyStr);
        } else {
            return key;
        }
    }

    @Override
    public String getDisplayString(String[] args) {
        return "distcache_gets()";
    }

    private static String getUsage() {
        return "\nUSAGE: "
                + "\n\tdistcache_gets(const string FILEPATH, object[] keys, const object defaultValue [, const boolean parseKey])::map<key_type, value_type>"
                + "\n\tdistcache_gets(const string FILEPATH, object key, const object defaultValue [, const boolean parseKey])::value_type"
                + "\n\tdistcache_gets(const string FILEPATH, object[] key, object[] defaultValues [, const boolean parseKey])::map<key_type, value_type>";
    }

}
TOP

Related Classes of hivemall.tools.mapred.DistributedCacheLookupUDF

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.