BZDEV: Possible solution to the e.t. "scalar" problem

From: Todd Veldhuizen (tveldhui@oonumerics.org)
Date: Thu Jun 18 1998 - 17:40:46 EST


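I think I might have a solution to the pesky "scalar" problem
with expression templates, i.e. making mixed expressions such as
A + 5, 5.0f + B or 5L + A work for any scalar type without writing a
separate overload for every array/scalar type combination.
Disclaimer: I'm not sure whether this would work in an
industrial-strength e.t. implementation yet.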

The basic trick is to have a traits class which maps from types
onto their appropriate expression template type, e.g.

Array -> ArrayIterator
Expr<T> -> Expr<T> (it's already an e.t. type)
IndexPlaceholder<N> -> IndexPlaceholder<N> (ditto)

The tricky part is that the default trait can be:

T -> ExprScalar<T>

so that any type which isn't recognized is assumed to be a scalar.
In the code below, this traits class is called asExpr<T>.

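Then, both Array and Expr<T> publicly inherit from a dummy template base
class a la Furnish (or a la the "Barton and Nackman trick"). The
inheritance must be public; otherwise binding an Array or Expr to a
KludgeBase<...>& parameter would be inaccessible:

class Array : public KludgeBase<Array> {
};

template<class T>
class Expr : public KludgeBase<Expr<T> > {
};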

Only 4 versions of each operator have to be provided:

operator+(Array, T)
operator+(Expr<>, T)
operator+(T, KludgeBase<Array>)
operator+(T, KludgeBase<Expr<>>)

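The KludgeBase disambiguates between the two possible matches for
Array+Array: binding an Array to a KludgeBase<Array>& parameter is a
derived-to-base conversion, which ranks below the exact match offered by
operator+(Array, T), so overload resolution prefers the latter.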

Sketchy code follows.

#include <iostream.h>

// Empty class template used purely as a type tag.  A type such as Array
// derives from KludgeBase<Array>, so an operator+ parameter declared as a
// KludgeBase<...> reference only accepts that family of types.
template<typename T_derived>
class KludgeBase
{
};

// Wrapper marking an expression-template node T as an "expression".
// Deriving from KludgeBase<Expr<T> > lets the scalar-on-the-left operator+
// overload (the one taking a KludgeBase<Expr<T_expr> > reference) match it.
template<class T>
struct Expr : public KludgeBase<Expr<T> > {
    T expr_;

    // Wrap an existing node.
    Expr(T x)
      : expr_(x)
    { }

    // Build the wrapped node in place from two operands; the operator+
    // overloads use this to construct Expr<BinExpr<...> >(x, y) directly.
    template<class T1, class T2>
    Expr(T1 x, T2 y)
      : expr_(x,y)
    { }

    // Element i of the expression, evaluated lazily by the wrapped node.
    double operator[](int i)
    { return expr_[i]; }

    // Const overload so a const Expr can be indexed as well.  As a member of
    // a class template it is only instantiated when called, so it places no
    // extra requirement on T for non-const use.
    double operator[](int i) const
    { return expr_[i]; }
};

// Binary expression-template node: element i is
// T_op::apply(leftNode_[i], rightNode_[i]).  Both operands are stored by
// value; T_op must provide a static apply(double, double).
template<class T_op, class T1, class T2>
struct BinExpr {
    T1 leftNode_;
    T2 rightNode_;

    // Initialise each operand from any type it is constructible from
    // (e.g. a raw scalar initialises an ExprScalar operand).
    template<class T_init1, class T_init2>
    BinExpr(T_init1 t1, T_init2 t2)
        : leftNode_(t1), rightNode_(t2)
    { }

    // Evaluate element i lazily.
    double operator[](int i)
    { return T_op::apply(leftNode_[i],rightNode_[i]); }

    // Const overload so a const node can be evaluated too.  It is only
    // instantiated when called, so operand types lacking a const
    // operator[] remain usable through the non-const overload.
    double operator[](int i) const
    { return T_op::apply(leftNode_[i],rightNode_[i]); }
};

// Leaf node wrapping a scalar so it can be an operand of BinExpr: every
// index yields the same value, converted to double.
template<class T_numtype>
struct ExprScalar {
    T_numtype value_;

    ExprScalar(T_numtype value)
        : value_(value)
    { }

    // The index is ignored.  Const because reading the scalar never
    // modifies it; still callable on non-const objects.
    double operator[](int) const
    { return value_; }
};

struct Array : public KludgeBase<Array> {
    typedef Array T_iterator;

    Array(double* data, int N)
        : data_(data), N_(N)
    { }

    template<class T_expr>
    void operator=(Expr<T_expr> expression)
    {
        for (int i=0; i < N_; ++i)
            data_[i] = expression[i];
    }

    double operator[](int i)
    { return data_[i]; }

    void dump()
    {
        for (int i=0; i < N_; ++i)
            cout << data_[i] << " ";
        cout << endl;
    }

    double* data_;
    int N_;
};

// Element-wise addition policy for BinExpr: apply(lhs, rhs) == lhs + rhs.
struct plus
{
    static double apply(double lhs, double rhs)
    {
        return lhs + rhs;
    }
};

template<class T>
struct asExpr {
    typedef ExprScalar<T> T_expr;
};

template<class T>
struct asExpr<Expr<T> > {
    typedef Expr<T> T_expr;
};

template<>
struct asExpr<Array> {
    typedef Array::T_iterator T_expr;
};

// Array + anything.  The Array operand is stored via Array::T_iterator and
// the right-hand operand via asExpr<T2>::T_expr.
template<class T2>
Expr<BinExpr<plus, Array::T_iterator, typename asExpr<T2>::T_expr> >
operator+(Array x, T2 y)
{
    typedef typename asExpr<T2>::T_expr T_rightNode;
    typedef Expr<BinExpr<plus, Array::T_iterator, T_rightNode> > T_result;
    return T_result(x, y);
}

// Expr + anything.  The left expression is stored as-is and the right-hand
// operand via asExpr<T2>::T_expr.
template<class T_expr, class T2>
Expr<BinExpr<plus, Expr<T_expr>, typename asExpr<T2>::T_expr> >
operator+(Expr<T_expr> x, T2 y)
{
    typedef typename asExpr<T2>::T_expr T_rightNode;
    typedef Expr<BinExpr<plus, Expr<T_expr>, T_rightNode> > T_result;
    return T_result(x, y);
}

// scalar + Array.  The Array is taken as a KludgeBase<Array> reference so
// that, for Array + Array, binding to it is a derived-to-base conversion and
// the exact-match operator+(Array, T2) is preferred (no ambiguity).  The
// reference is const so Array temporaries can bind as well as named Arrays.
template<class T1>
Expr<BinExpr<plus, typename asExpr<T1>::T_expr,
    Array::T_iterator> >
operator+(T1 x, const KludgeBase<Array>& y)
{
    // Downcast assumes y is the base subobject of an Array, which is how
    // KludgeBase<Array> is used in this file.
    return Expr<BinExpr<plus, typename asExpr<T1>::T_expr,
        Array::T_iterator> >(x, static_cast<const Array&>(y));
}

// scalar + Expr.  Expressions are almost always temporaries (e.g. the
// result of A + B), so the KludgeBase parameter must be a const reference:
// with a non-const lvalue reference, `5.0 + (A + B)` fails to compile.
// For Expr + Expr, binding here is a derived-to-base conversion, so the
// exact-match operator+(Expr<T_expr>, T2) is still preferred.
template<class T1, class T_expr>
Expr<BinExpr<plus, typename asExpr<T1>::T_expr,
    Expr<T_expr> > >
operator+(T1 x, const KludgeBase<Expr<T_expr> >& y)
{
    // Downcast assumes y is the base subobject of an Expr<T_expr>, which is
    // how KludgeBase<Expr<...> > is used in this file.
    return Expr<BinExpr<plus, typename asExpr<T1>::T_expr,
        Expr<T_expr> > >(x, static_cast<const Expr<T_expr>&>(y));
}

int main()
{
    double a_data[4] = { 1, 2, 3, 4};
    double b_data[4] = { 5, 6, 7, 8};
    double c_data[4];

    Array A(a_data,4), B(b_data,4), C(c_data,4);
    C = A + 5; C.dump();
    C = A + 5.0; C.dump();
    C = A + B; C.dump();
    C = 5.0f + B; C.dump();

    C = 5L + A + (A + B) + 3.0f; C.dump();

    return 0;
}

--------------------- blitz-dev list --------------------------------
* To subscribe/unsubscribe: mail to majordomo@oonumerics.org, with
"subscribe blitz-dev" or "unsubscribe blitz-dev" in the body of the message
* Blitz++ web page: http://oonumerics.org/blitz/



This archive was generated by hypermail 2b29 : Wed Feb 20 2002 - 04:30:05 EST