Package org.grouplens.lenskit.scored

Source Code of org.grouplens.lenskit.scored.PackedScoredIdList

/*
* LensKit, an open source recommender systems toolkit.
* Copyright 2010-2014 LensKit Contributors.  See CONTRIBUTORS.md.
* Work on LensKit has been funded by the National Science Foundation under
* grants IIS 05-34939, 08-08692, 08-12148, and 10-17697.
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as
* published by the Free Software Foundation; either version 2.1 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* this program; if not, write to the Free Software Foundation, Inc., 51
* Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
*/
package org.grouplens.lenskit.scored;

import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicates;
import com.google.common.collect.Collections2;
import com.google.common.collect.FluentIterable;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import it.unimi.dsi.fastutil.longs.LongList;
import org.grouplens.lenskit.collections.FastCollection;
import org.grouplens.lenskit.symbols.DoubleSymbolValue;
import org.grouplens.lenskit.symbols.Symbol;
import org.grouplens.lenskit.symbols.SymbolValue;
import org.grouplens.lenskit.symbols.TypedSymbol;

import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import java.io.IOException;
import java.io.InvalidObjectException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.util.*;

/**
* A space-efficient, unboxed list of scored IDs.  These lists are immutable; create them using a
* {@link ScoredIdListBuilder} (see {@link ScoredIds#newListBuilder()}.
*
* @since 1.4
* @compat Public
* @author <a href="http://www.grouplens.org">GroupLens Research</a>
*/
@SuppressWarnings("deprecation")
public final class PackedScoredIdList extends AbstractList<ScoredId> implements FastCollection<ScoredId>, Serializable {
    private static final long serialVersionUID = 1L;
    private final LongList ids;
    private final DoubleList scores;
    private final Map<Symbol, DoubleList> unboxedChannels;
    private final Map<TypedSymbol<?>, List<?>> channels;

    PackedScoredIdList(LongList ids, DoubleList scores,
                       Map<TypedSymbol<?>, List<?>> chans,
                       Map<Symbol, DoubleList> unboxedChans) {
        assert ids.size() == scores.size();
        this.ids = ids;
        this.scores = scores;
        unboxedChannels = unboxedChans;
        channels = chans;
    }

    /**
     * Do some light validation of the scored ID list.
     * @param in The input stream
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        if (ids.size() != scores.size()) {
            throw new InvalidObjectException("score array has incorrect size");
        }
        for (DoubleList chan: unboxedChannels.values()) {
            if (chan.size() != ids.size()) {
                throw new InvalidObjectException("channel array has incorrect size");
            }
        }
        for (Map.Entry<TypedSymbol<?>,List<?>> tc: channels.entrySet()) {
            if (tc.getValue().size() != ids.size()) {
                throw new InvalidObjectException("channel array has incorrect size");
            }
        }
    }

    @Override
    public int size() {
        return ids.size();
    }

    @Override
    @Deprecated
    public Iterator<ScoredId> fastIterator() {
        return iterator();
    }

    @Override
    public ScoredId get(int i) {
        Preconditions.checkElementIndex(i, size());
        return getFlyweight(i);
    }

    /**
     * Get a flyweight id from the list, at the specified index.
     * @param i The index.  No validation is performed of the index; the scored id will fail if the
     *          index is out-of-bounds.  This is to allow other classes in the package (such as the
     *          builder) to have unlimited access to scored ids.
     * @return A flyweight at the specified position.
     */
    IndirectScoredId getFlyweight(int i) {
        return new IndirectScoredId(i);
    }

    /**
     * Flyweight implementation of {@link ScoredId} backed by the list's storage.
     */
    class IndirectScoredId extends AbstractScoredId implements Serializable {
        private static final long serialVersionUID = 1L;

        private int index;

        public IndirectScoredId(int idx) {
            index = idx;
        }

        public Object writeReplace() {
            return ScoredIds.copyBuilder(this).build();
        }

        public void setIndex(int idx) {
            index = idx;
        }

        @Override
        public long getId() {
            return ids.get(index);
        }

        @Override
        public double getScore() {
            return scores.get(index);
        }

        @Override
        public Set<Symbol> getUnboxedChannelSymbols() {
            return unboxedChannels.keySet();
        }

        @Override
        public Set<TypedSymbol<?>> getChannelSymbols() {
            ImmutableSet.Builder<TypedSymbol<?>> bld = ImmutableSet.builder();
            for (Symbol s: unboxedChannels.keySet()) {
                bld.add(s.withType(Double.class));
            }
            for (Map.Entry<TypedSymbol<?>, List<?>> e: channels.entrySet()) {
                if (e.getValue().get(index) != null) {
                    bld.add(e.getKey());
                }
            }
            return bld.build();
        }

        @Nonnull
        @Override
        public Collection<SymbolValue<?>> getChannels() {
            // FIXME Make this fast
            List<SymbolValue<?>> channels = Lists.newArrayList();
            FluentIterable.from(PackedScoredIdList.this.channels.entrySet())
                    .transform(new Function<Map.Entry<TypedSymbol<?>, List<?>>, SymbolValue<?>>() {
                        @SuppressWarnings({"rawtypes", "unchecked"})
                        @Nullable
                        @Override
                        public SymbolValue<?> apply(@Nullable Map.Entry<TypedSymbol<?>, List<?>> input) {
                            assert input != null;
                            Object obj = input.getValue().get(index);
                            if (obj != null) {
                                TypedSymbol sym = input.getKey();
                                assert sym.getType().isInstance(obj);
                                return sym.withValue(obj);
                            } else {
                                return null;
                            }
                        }
                    }).filter(Predicates.notNull())
                    .copyInto(channels);
            channels.addAll(getUnboxedChannels());
            return channels;
        }

        @Nonnull
        @Override
        public Collection<DoubleSymbolValue> getUnboxedChannels() {
            return Collections2.transform(
                    unboxedChannels.entrySet(),
                    new Function<Map.Entry<Symbol, DoubleList>, DoubleSymbolValue>() {
                        @Nullable
                        @Override
                        public DoubleSymbolValue apply(@Nullable Map.Entry<Symbol, DoubleList> input) {
                            assert input != null;
                            return SymbolValue.of(input.getKey(), input.getValue().get(index));
                        }
                    });
        }

        @Nullable
        @Override
        public <T> T getChannelValue(@Nonnull TypedSymbol<T> sym) {
            if (sym.getType().equals(Double.class) && hasUnboxedChannel(sym.getRawSymbol())) {
                return sym.getType().cast(getUnboxedChannelValue(sym.getRawSymbol()));
            } else {
                List<?> array = channels.get(sym);
                if (array == null) {
                    return null;
                } else {
                    return sym.getType().cast(array.get(index));
                }
            }
        }

        @Override
        public double getUnboxedChannelValue(Symbol sym) {
            DoubleList array = unboxedChannels.get(sym);
            if (array != null) {
                return array.get(index);
            } else {
                throw new NullPointerException("no symbol " + sym);
            }
        }

        @Override
        public boolean hasUnboxedChannel(Symbol s) {
            return unboxedChannels.containsKey(s);
        }

        @Override
        public boolean hasChannel(TypedSymbol<?> s) {
            if (s.getType().equals(Double.class) && hasUnboxedChannel(s.getRawSymbol())) {
                return true;
            } else {
                List<?> obj = channels.get(s);
                return obj != null && obj.get(index) != null;
            }
        }
    }
}
TOP

Related Classes of org.grouplens.lenskit.scored.PackedScoredIdList

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.