Package cc.redberry.transformation

Source Code of cc.redberry.transformation.GetDerivative1

/*
* Redberry: symbolic tensor computations.
*
* Copyright (c) 2010-2012:
*   Stanislav Poslavsky   <stvlpos@mail.ru>
*   Bolotin Dmitriy       <bolotin.dmitriy@gmail.com>
*
* This file is part of Redberry.
*
* Redberry is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Redberry is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with Redberry. If not, see <http://www.gnu.org/licenses/>.
*/
package cc.redberry.transformation;

import java.util.ArrayList;
import java.util.List;
import cc.redberry.core.context.CC;
import cc.redberry.core.indices.Indices;
import cc.redberry.core.number.ComplexElement;
import cc.redberry.core.tensor.AbstractScalarFunction;
import cc.redberry.core.tensor.Derivative;
import cc.redberry.core.tensor.Product;
import cc.redberry.core.tensor.SimpleTensor;
import cc.redberry.core.tensor.Sum;
import cc.redberry.core.tensor.Tensor;
import cc.redberry.core.tensor.TensorField;
import cc.redberry.core.tensor.TensorIterator;
import cc.redberry.core.tensor.TensorNumber;

/**
*
* @author Dmitry Bolotin
* @author Stanislav Poslavsky
*/
public class GetDerivative1 implements Transformation {
    public static GetDerivative1 INSTANCE = new GetDerivative1();

    private GetDerivative1() {
    }

    public Tensor transform(Tensor tensor) {
        if (!(tensor instanceof Derivative))
            return tensor;
        Derivative derivative = (Derivative) tensor;
        Tensor target = derivative.getTarget().clone();
        for (int i = 0; i < derivative.getDerivativeOrder(); ++i) {
            target = getDerivative(target, derivative.getVariation(i));
            if (target == null)
                return TensorNumber.createZERO();
        }
        return target;
    }

    private Tensor getDerivative(Tensor target, SimpleTensor var) {
        if (target instanceof Sum) {
            Sum sum = (Sum) target;
            TensorIterator it = sum.iterator();
            Tensor current, derivative;
            Sum res = new Sum();
            while (it.hasNext()) {
                current = it.next();
                derivative = getDerivative(current, var);
                if (derivative == null)
                    continue;// it.remove();
                else
                    res.add(derivative);//if (derivative != current)
                //it.set(derivative);
            }
            if (res.isEmpty())
                return null;
            return res.equivalent();
        } else if (target instanceof Product) {
            Product product = (Product) target;
            Tensor derivative;
            List<Tensor> resultProducts = new ArrayList<>();
            for (int i = 0; i < product.size(); ++i) {
                derivative = getDerivative(product.getElements().get(i), var);
                if (derivative == null)
                    continue;
                Product clone = (Product) product.clone();
                clone.getElements().remove(i);
                if (!isOne(derivative))
                    clone.add(derivative);
                resultProducts.add(clone.equivalent());
            }
            if (resultProducts.isEmpty())
                return null;
            if (resultProducts.size() == 1)
                return resultProducts.get(0);
            return new Sum(resultProducts);
        } else if (target.getClass() == SimpleTensor.class) {
            SimpleTensor sp = (SimpleTensor) target;
            if (sp.getName() != var.getName())
                return null;
            if (sp.getIndices().size() == 0)
                return TensorNumber.createONE();
            Product kroneckers = new Product();
            Indices targetIndices = sp.getIndices();
            Indices varIndices = var.getIndices();
            for (int i = 0; i < sp.getIndices().size(); ++i)
                kroneckers.add(CC.createMetricOrKronecker(targetIndices.get(i), varIndices.get(i)));
            return kroneckers.equivalent();
        } else if (target.getClass() == TensorField.class) {
            TensorField field = (TensorField) target;
            Tensor[] args = field.getArgs();
            for (int i = 0; i < args.length; ++i)
                if (getDerivative(args[i], var) != null)
                    return Derivative.createFromInversed(target,var);
            return null;
        } else if (target instanceof AbstractScalarFunction) {
            AbstractScalarFunction func = (AbstractScalarFunction) target;
            Tensor der = getDerivative(func.getInnerTensor(), var);
            if (der == null)
                return null;
            if (isOne(der))
                return func.derivative();
            return new Product(func.derivative(), der);
        }
        //TODO get derivative ot derivative
        return null;
    }

    private static boolean isOne(Tensor t) {
        if (t instanceof TensorNumber && ((TensorNumber) t).getValue().equals(ComplexElement.ONE))
            return true;
        return false;
    }
}
TOP

Related Classes of cc.redberry.transformation.GetDerivative1

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by ORACLE Inc. Contact coftware#gmail.com.