Hi Todd,
Thanks for your analysis. Since I sent you my original suggestion, I have
come up with further refinements that will probably interest you. These
refinements obviate the need for extra operator overloadings.
The basic idea is to use inheritance instead of containment when defining
your _bz_ArrayExpr<> (see Expr<> in my example and note that the old Expr<>
has been renamed ExprBase<>).
The new Expr<E> wraps an E by inheriting from E. It also presents E's
constructor interface for use in constructing Expr<E> objects. Expr<E> can
now be used in the operator overloads, replacing the separate operators
defined for UnaryExpr<>, BinaryExpr<>, TernaryExpr<> (I think you call this
trinary), etc. So we are back down to 3 signatures for the binary operators.
More importantly, the expression tree contains far fewer nodes than it
currently does in Blitz++. I believe that this simplification helps the
compiler to produce better-optimized code.
The new version of array.cc is attached and it contains other minor bug
fixes and improvements. Also, do you have an opinion on the use of operator
tag classes? I think they are an elegant way to express the operators.
cheers,
Greg
//---------------
// array.cc
//---------------
#include <stdlib.h>
#include <iostream.h>
#include <math.h>
// Forward declarations of class templates that are named before they are defined.
template <class T> class Array;                                   // concrete storage
template <class E> class ScalarExpr;                              // wraps a single scalar value
template <class A, class B, class Op> class BinaryExpr;           // node with two operands
template <class A, class Op> class UnaryExpr;                     // node with one operand
template <class AType, class BType, class Op> class BinaryOp;     // element-wise binary operation
template <class AType, class Op> class UnaryOp;                   // element-wise unary operation
//-----------------------------
// Level
// (for printing expressions)
//-----------------------------
// Stream manipulator-style helper: inserting a Level into a stream writes
// one "| " marker per nesting level, which indents the tree printout.
class Level
{
    int _level;   // nesting depth to render

public:
    Level(int level) : _level(level) {}

    // Emits the "| " marker _level times; nothing is written for depth <= 0.
    friend std::ostream& operator<< (std::ostream& os, const Level& level)
    {
        for (int remaining = level._level; remaining > 0; --remaining)
        {
            os << "| ";
        }
        return os;
    }
};
//-----------------------
// PromoteTrait<T1,T2>
//-----------------------
// PromoteTrait<T1,T2>::Type is the result type used for a binary operation
// whose operands have element types T1 and T2.
//
// The primary template simply yields T1. That is already right whenever T1
// is the "wider" of the two types (e.g. <double,int>), so only the pairs
// where T1 alone would be the wrong answer are specialized below.
template <class T1, class T2>
struct PromoteTrait
{
    typedef T1 Type; // default to first type
};

// Explicit specializations must be introduced with `template <>`.
template <> struct PromoteTrait<bool, char>    { typedef int    Type; };
template <> struct PromoteTrait<bool, short>   { typedef int    Type; };
template <> struct PromoteTrait<bool, int>     { typedef int    Type; };
template <> struct PromoteTrait<bool, long>    { typedef long   Type; };
template <> struct PromoteTrait<bool, float>   { typedef float  Type; };
template <> struct PromoteTrait<bool, double>  { typedef double Type; };
template <> struct PromoteTrait<char, short>   { typedef int    Type; };
template <> struct PromoteTrait<char, int>     { typedef int    Type; };
template <> struct PromoteTrait<char, long>    { typedef long   Type; };
template <> struct PromoteTrait<char, float>   { typedef float  Type; };
template <> struct PromoteTrait<char, double>  { typedef double Type; };
template <> struct PromoteTrait<short, int>    { typedef int    Type; };
template <> struct PromoteTrait<short, long>   { typedef long   Type; };
template <> struct PromoteTrait<short, float>  { typedef float  Type; };
template <> struct PromoteTrait<short, double> { typedef double Type; };
template <> struct PromoteTrait<int, long>     { typedef long   Type; };
template <> struct PromoteTrait<int, float>    { typedef float  Type; };
template <> struct PromoteTrait<int, double>   { typedef double Type; };
template <> struct PromoteTrait<long, float>   { typedef float  Type; };
template <> struct PromoteTrait<long, double>  { typedef double Type; };
template <> struct PromoteTrait<float, double> { typedef double Type; };

// Mirror images of the mixed small-integer pairs above. Without them the
// "first type" default would give <short,char> a different result type
// (short) than <char,short> (int).
template <> struct PromoteTrait<char, bool>    { typedef int    Type; };
template <> struct PromoteTrait<short, bool>   { typedef int    Type; };
template <> struct PromoteTrait<short, char>   { typedef int    Type; };
//-----------------------
// Expr<>
//-----------------------
// Expr<E> is a thin wrapper that *is* an E (public inheritance) and
// re-exposes E's constructor. The primary template is empty; only the
// node kinds specialized below are usable.
template <class E>
class Expr
{
};

// Wrapper around a two-operand node.
template <class Lhs, class Rhs, class Tag>
class Expr< BinaryExpr<Lhs,Rhs,Tag> > : public BinaryExpr<Lhs,Rhs,Tag>
{
    typedef BinaryExpr<Lhs,Rhs,Tag> Node;   // the wrapped base class

public:
    // Forwards both operands to the underlying BinaryExpr.
    Expr(const Lhs& a, const Rhs& b) : Node(a, b) {}
};

// Wrapper around a one-operand node.
template <class Operand, class Tag>
class Expr< UnaryExpr<Operand,Tag> > : public UnaryExpr<Operand,Tag>
{
    typedef UnaryExpr<Operand,Tag> Node;    // the wrapped base class

public:
    // Forwards the operand to the underlying UnaryExpr.
    explicit Expr(const Operand& a) : Node(a) {}
};
//-----------------------
// MakeExpr<>
//-----------------------
template <class E>
struct MakeExpr
{
typedef ScalarExpr<E> ExprType;
typedef ExprType::Type ResultType;
};
template <class E>
struct MakeExpr< ScalarExpr<E> >
{
typedef ScalarExpr<E> ExprType;
typedef ExprType::Type ResultType;
};
template <class T>
struct MakeExpr< Array<T> >
{
typedef const Array<T>& ExprType;
typedef T ResultType;
};
template <class A, class B, class Op>
struct MakeExpr< BinaryExpr<A,B,Op> >
{
typedef BinaryExpr<A,B,Op> ExprType;
typedef ExprType::Type ResultType;
};
template <class A, class Op>
struct MakeExpr< UnaryExpr<A,Op> >
{
typedef UnaryExpr<A,Op> ExprType;
typedef ExprType::Type ResultType;
};
template <class E>
struct MakeExpr< Expr<E> >
{
typedef MakeExpr<E>::ExprType ExprType;
typedef ExprType::Type ResultType;
};
//-----------------------
// Base Expression:
// ExprBase<E>
//-----------------------
// Empty tag base. Each expression class derives from ExprBase<itself>, so a
// function taking `const ExprBase<E>&` can recover the concrete type E.
template <class E>
class ExprBase
{
};
//-----------------------
// Scalar Expression:
// ScalarExpr<E>
//-----------------------
//
// default to scalar expression
//
// Leaf expression holding a copy of one scalar value. Indexing it returns
// that same value for every index.
template <class S>
class ScalarExpr : public ExprBase< ScalarExpr<S> >
{
public:
    typedef S Type;

    ScalarExpr(const S& s)
        : _s(s)
    {}

    // The index is ignored: a scalar has the same value at every position.
    Type operator() (int /*i*/) const
    { return _s; }

    // Prints this leaf at the given nesting level, with the value one level deeper.
    void treePrint(std::ostream& os, int level = 0) const
    {
        // endl is qualified to match the file's use of std::ostream.
        os << Level(level) << "ScalarExpr:" << std::endl;
        os << Level(level+1) << _s << std::endl;
    }

private:
    const S _s;
};
//-----------------------------
// Binary Expression:
// BinaryExpr<A,B,Op>
//-----------------------------
// Expression node with two operands. Indexing it applies the BinaryOp
// selected by the tag class Op to the i-th element of each operand.
template <class A, class B, class Op>
class BinaryExpr : public ExprBase< BinaryExpr<A,B,Op> >
{
public:
    // How each operand is held and what element type it yields are decided
    // by MakeExpr. These are dependent types, so `typename` is required.
    typedef typename MakeExpr<A>::ExprType AExpr;
    typedef typename MakeExpr<B>::ExprType BExpr;
    typedef typename MakeExpr<A>::ResultType AType;
    typedef typename MakeExpr<B>::ResultType BType;
    typedef BinaryOp<AType, BType, Op> Operation;
    typedef typename Operation::Type Type;

    BinaryExpr(const A& a, const B& b)
        : _a(a), _b(b)
    {}

    // Element i of the result.
    Type operator() (int i) const
    { return Operation::apply(_a(i), _b(i)); }

    // Prints this node, then both operands and the operation one level deeper.
    void treePrint(std::ostream& os, int level = 0) const
    {
        os << Level(level) << "BinaryExpr:" << std::endl;
        _a.treePrint(os, level+1);
        _b.treePrint(os, level+1);
        Operation::treePrint(os, level+1);
    }

private:
    AExpr _a;
    BExpr _b;
};
//-----------------------------
// Unary Expression:
// UnaryExpr<A,Op>
//-----------------------------
// Expression node with one operand. Indexing it applies the UnaryOp
// selected by the tag class Op to the i-th element of the operand.
template <class A, class Op>
class UnaryExpr : public ExprBase< UnaryExpr<A,Op> >
{
public:
    // Dependent types: `typename` is required.
    typedef typename MakeExpr<A>::ExprType AExpr;
    typedef typename MakeExpr<A>::ResultType AType;
    typedef UnaryOp<AType, Op> Operation;
    typedef typename Operation::Type Type;

    UnaryExpr(const A& a)
        : _a(a)
    {}

    // Element i of the result.
    Type operator() (int i) const
    { return Operation::apply(_a(i)); }

    // Prints this node, then the operand and the operation one level deeper.
    void treePrint(std::ostream& os, int level = 0) const
    {
        os << Level(level) << "UnaryExpr:" << std::endl;
        _a.treePrint(os, level+1);
        Operation::treePrint(os, level+1);
    }

private:
    AExpr _a;
};
//=======================
// Operators
//=======================
//--------------------------
// BinaryOp<AType,BType,Op>
//--------------------------
// Primary template: deliberately empty. The usable definitions are the
// per-tag partial specializations generated by the BINARY_OP_* macros.
template <class AType, class BType, class Op>
class BinaryOp
{
    // default: never instantiated
};
// BINARY_OP_PROMOTE(Op, Name)
// Generates the BinaryOp specialization for tag class Name:
//   Type      - PromoteTrait<AType,BType>::Type
//   apply     - returns `a Op b`
//   treePrint - prints the tag name at the given nesting level
// Every line of the macro body must end in a backslash with the next token
// on the following line; a backslash alone on its own line ends up leaving
// the #define empty.
#define BINARY_OP_PROMOTE(Op,Name)                                      \
template <class AType, class BType>                                     \
class BinaryOp<AType, BType, Name>                                      \
{                                                                       \
public:                                                                 \
    typedef typename PromoteTrait<AType, BType>::Type Type;             \
    static Type apply(const AType& a, const BType& b)                   \
    { return a Op b; }                                                  \
    static void treePrint(std::ostream& os, int level = 0)              \
    { os << Level(level) << #Name << std::endl; }                       \
};
// BINARY_OP_RETTYPE(Op, Name, RetType)
// Generates the BinaryOp specialization for tag class Name:
//   Type      - fixed to RetType regardless of the operand types
//   apply     - returns `a Op b`
//   treePrint - prints the tag name at the given nesting level
// Continuation backslashes must terminate the line they continue.
#define BINARY_OP_RETTYPE(Op,Name,RetType)                              \
template <class AType, class BType>                                     \
class BinaryOp<AType, BType, Name>                                      \
{                                                                       \
public:                                                                 \
    typedef RetType Type;                                               \
    static Type apply(const AType& a, const BType& b)                   \
    { return a Op b; }                                                  \
    static void treePrint(std::ostream& os, int level = 0)              \
    { os << Level(level) << #Name << std::endl; }                       \
};
// BINARY_OP_FUNC_PROMOTE(FuncOp, Name)
// Generates the BinaryOp specialization for tag class Name:
//   Type      - PromoteTrait<AType,BType>::Type
//   apply     - returns `FuncOp(a, b)`
//   treePrint - prints the tag name at the given nesting level
// Continuation backslashes must terminate the line they continue.
#define BINARY_OP_FUNC_PROMOTE(FuncOp,Name)                             \
template <class AType, class BType>                                     \
class BinaryOp<AType, BType, Name>                                      \
{                                                                       \
public:                                                                 \
    typedef typename PromoteTrait<AType, BType>::Type Type;             \
    static Type apply(const AType& a, const BType& b)                   \
    { return FuncOp(a, b); }                                            \
    static void treePrint(std::ostream& os, int level = 0)              \
    { os << Level(level) << #Name << std::endl; }                       \
};
// BINARY_OP_FUNC_RETTYPE(FuncOp, Name, RetType)
// Generates the BinaryOp specialization for tag class Name:
//   Type      - fixed to RetType regardless of the operand types
//   apply     - returns `FuncOp(a, b)`
//   treePrint - prints the tag name at the given nesting level
// Continuation backslashes must terminate the line they continue.
#define BINARY_OP_FUNC_RETTYPE(FuncOp,Name,RetType)                     \
template <class AType, class BType>                                     \
class BinaryOp<AType, BType, Name>                                      \
{                                                                       \
public:                                                                 \
    typedef RetType Type;                                               \
    static Type apply(const AType& a, const BType& b)                   \
    { return FuncOp(a, b); }                                            \
    static void treePrint(std::ostream& os, int level = 0)              \
    { os << Level(level) << #Name << std::endl; }                       \
};
//--------------------------------
// Binary operator overloadings
//--------------------------------
// BINARY_EXPR_OPERATOR(FuncOp, Name)
// Declares the empty tag class Name and three overloads of FuncOp that
// build an Expr< BinaryExpr<...,Name> > instead of computing a value:
//   1. Array<A>  on the left, anything on the right
//   2. Expr<A>   on the left, anything on the right
//   3. anything on the left, an ExprBase<B>-derived object on the right;
//      the right operand is cast back to its concrete type B.
// Continuation backslashes must terminate the line they continue.
#define BINARY_EXPR_OPERATOR(FuncOp,Name)                                   \
                                                                            \
class Name {};                                                              \
                                                                            \
template <class A, class B>                                                 \
Expr< BinaryExpr<Array<A>,B,Name> >                                         \
FuncOp(const Array<A>& a, const B& b)                                       \
{                                                                           \
    return Expr< BinaryExpr<Array<A>,B,Name> >(a, b);                       \
}                                                                           \
                                                                            \
template <class A, class B>                                                 \
Expr< BinaryExpr<Expr<A>,B,Name> >                                          \
FuncOp(const Expr<A>& a, const B& b)                                        \
{                                                                           \
    return Expr< BinaryExpr<Expr<A>,B,Name> >(a, b);                        \
}                                                                           \
                                                                            \
template <class A, class B>                                                 \
Expr< BinaryExpr<A,B,Name> >                                                \
FuncOp(const A& a, const ExprBase<B>& b)                                    \
{                                                                           \
    return Expr< BinaryExpr<A,B,Name> >(a, static_cast<const B&>(b));       \
}
// Convenience macros: each one emits the operator overloads (which also
// declare the tag class) followed by the matching BinaryOp specialization.
// The last line of each macro must NOT end in a backslash, otherwise the
// following line is spliced into the macro body.

// Built-in operator `Op`, result type from PromoteTrait.
#define DEF_BINARY_OP_PROMOTE(Op,Name) \
BINARY_EXPR_OPERATOR(operator Op, Name) \
BINARY_OP_PROMOTE(Op, Name)

// Built-in operator `Op`, result type fixed to RetType.
#define DEF_BINARY_OP_RETTYPE(Op,Name,RetType) \
BINARY_EXPR_OPERATOR(operator Op, Name) \
BINARY_OP_RETTYPE(Op, Name, RetType)

// Named function Func, result type from PromoteTrait.
#define DEF_BINARY_FUNC_PROMOTE(Func,Name) \
BINARY_EXPR_OPERATOR(Func, Name) \
BINARY_OP_FUNC_PROMOTE(Func, Name)

// Named function Func, result type fixed to RetType.
#define DEF_BINARY_FUNC_RETTYPE(Func,Name,RetType) \
BINARY_EXPR_OPERATOR(Func, Name) \
BINARY_OP_FUNC_RETTYPE(Func, Name, RetType)

DEF_BINARY_OP_PROMOTE( + , BinaryOp_Plus) // first group: result type comes from PromoteTrait
DEF_BINARY_OP_PROMOTE( - , BinaryOp_Minus)
DEF_BINARY_OP_PROMOTE( * , BinaryOp_Multiply)
DEF_BINARY_OP_PROMOTE( / , BinaryOp_Divide)
DEF_BINARY_OP_PROMOTE( % , BinaryOp_Modulus)
DEF_BINARY_OP_PROMOTE( & , BinaryOp_BitwiseAnd)
DEF_BINARY_OP_PROMOTE( | , BinaryOp_BitwiseOr)
DEF_BINARY_OP_PROMOTE( ^ , BinaryOp_BitwiseXor)
DEF_BINARY_OP_PROMOTE( >>, BinaryOp_ShiftRight)
DEF_BINARY_OP_PROMOTE( <<, BinaryOp_ShiftLeft)
// Second group: logical and relational operators, result type fixed to bool.
DEF_BINARY_OP_RETTYPE( &&, BinaryOp_LogicalAnd, bool)
DEF_BINARY_OP_RETTYPE( ||, BinaryOp_LogicalOr, bool)
DEF_BINARY_OP_RETTYPE( > , BinaryOp_GreaterThan, bool)
DEF_BINARY_OP_RETTYPE( >=, BinaryOp_GreaterEqual, bool)
DEF_BINARY_OP_RETTYPE( < , BinaryOp_LessThan, bool)
DEF_BINARY_OP_RETTYPE( <=, BinaryOp_LessEqual, bool)
DEF_BINARY_OP_RETTYPE( ==, BinaryOp_Equal, bool)
DEF_BINARY_OP_RETTYPE( !=, BinaryOp_NotEqual, bool)
//-----------------------
// UnaryOp<AType,Op>
//-----------------------
// Primary template: deliberately empty. The usable definitions are the
// per-tag partial specializations generated by the UNARY_OP_* macros.
template <class AType, class Op>
class UnaryOp
{
    // default: never instantiated
};
// UNARY_OP_DEFAULT(Op, Name)
// Generates the UnaryOp specialization for tag class Name:
//   Type      - same as the operand type A
//   apply     - returns `Op a`
//   treePrint - prints the tag name at the given nesting level
// Continuation backslashes must terminate the line they continue.
#define UNARY_OP_DEFAULT(Op,Name)                                       \
template <class A>                                                      \
class UnaryOp<A, Name>                                                  \
{                                                                       \
public:                                                                 \
    typedef A Type;                                                     \
    static Type apply(const A& a)                                       \
    { return Op a; }                                                    \
    static void treePrint(std::ostream& os, int level = 0)              \
    { os << Level(level) << #Name << std::endl; }                       \
};
// UNARY_OP_RETTYPE(Op, Name, RetType)
// Generates the UnaryOp specialization for tag class Name:
//   Type      - fixed to RetType regardless of the operand type
//   apply     - returns `Op a`
//   treePrint - prints the tag name at the given nesting level
// apply() takes the operand as `const A&`, not `const Type&`: Type is the
// *result* type here, and taking the argument as Type would convert the
// operand to RetType before the operator is applied.
// Continuation backslashes must terminate the line they continue.
#define UNARY_OP_RETTYPE(Op,Name,RetType)                               \
template <class A>                                                      \
class UnaryOp<A, Name>                                                  \
{                                                                       \
public:                                                                 \
    typedef RetType Type;                                               \
    static Type apply(const A& a)                                       \
    { return Op a; }                                                    \
    static void treePrint(std::ostream& os, int level = 0)              \
    { os << Level(level) << #Name << std::endl; }                       \
};
// UNARY_OP_FUNC_DEFAULT(FuncOp, Name)
// Generates the UnaryOp specialization for tag class Name:
//   Type      - same as the operand type A
//   apply     - returns `FuncOp(a)`
//   treePrint - prints the tag name at the given nesting level
// Continuation backslashes must terminate the line they continue.
#define UNARY_OP_FUNC_DEFAULT(FuncOp,Name)                              \
template <class A>                                                      \
class UnaryOp<A, Name>                                                  \
{                                                                       \
public:                                                                 \
    typedef A Type;                                                     \
    static Type apply(const A& a)                                       \
    { return FuncOp(a); }                                               \
    static void treePrint(std::ostream& os, int level = 0)              \
    { os << Level(level) << #Name << std::endl; }                       \
};
// UNARY_OP_FUNC_RETTYPE(FuncOp, Name, RetType)
// Generates the UnaryOp specialization for tag class Name:
//   Type      - fixed to RetType regardless of the operand type
//   apply     - returns `FuncOp(a)`
//   treePrint - prints the tag name at the given nesting level
// apply() takes the operand as `const A&`, not `const Type&`: Type is the
// *result* type here, and taking the argument as Type would convert the
// operand to RetType before FuncOp ever sees it.
// Continuation backslashes must terminate the line they continue.
#define UNARY_OP_FUNC_RETTYPE(FuncOp,Name,RetType)                      \
template <class A>                                                      \
class UnaryOp<A, Name>                                                  \
{                                                                       \
public:                                                                 \
    typedef RetType Type;                                               \
    static Type apply(const A& a)                                       \
    { return FuncOp(a); }                                               \
    static void treePrint(std::ostream& os, int level = 0)              \
    { os << Level(level) << #Name << std::endl; }                       \
};
//--------------------------------
// Unary operator overloadings
//--------------------------------
// UNARY_EXPR_OPERATOR(FuncOp, Name)
// Declares the empty tag class Name and one overload of FuncOp that accepts
// any ExprBase<A>-derived operand, casts it back to its concrete type A and
// wraps it in an Expr< UnaryExpr<A,Name> > instead of computing a value.
// Continuation backslashes must terminate the line they continue.
#define UNARY_EXPR_OPERATOR(FuncOp,Name)                                \
                                                                        \
class Name {};                                                          \
                                                                        \
template <class A>                                                      \
Expr< UnaryExpr<A,Name> >                                               \
FuncOp(const ExprBase<A>& a)                                            \
{                                                                       \
    return Expr< UnaryExpr<A,Name> >(static_cast<const A&>(a));         \
}
// Convenience macros: each one emits the operator overload (which also
// declares the tag class) followed by the matching UnaryOp specialization.
// The last line of each macro must NOT end in a backslash, otherwise the
// following line is spliced into the macro body.

// Built-in operator `Op`, result type equal to the operand type.
#define DEF_UNARY_OP_DEFAULT(Op,Name) \
UNARY_EXPR_OPERATOR(operator Op, Name) \
UNARY_OP_DEFAULT(Op, Name)

// Built-in operator `Op`, result type fixed to RetType.
#define DEF_UNARY_OP_RETTYPE(Op,Name,RetType) \
UNARY_EXPR_OPERATOR(operator Op, Name) \
UNARY_OP_RETTYPE(Op, Name, RetType)

// Named function Func, result type equal to the operand type.
#define DEF_UNARY_FUNC_DEFAULT(Func,Name) \
UNARY_EXPR_OPERATOR(Func, Name) \
UNARY_OP_FUNC_DEFAULT(Func, Name)

// Named function Func, result type fixed to RetType.
#define DEF_UNARY_FUNC_RETTYPE(Func,Name,RetType) \
UNARY_EXPR_OPERATOR(Func, Name) \
UNARY_OP_FUNC_RETTYPE(Func, Name, RetType)

DEF_UNARY_OP_DEFAULT( - , UnaryOp_Minus) // result type equals the operand type
DEF_UNARY_OP_DEFAULT( ~ , UnaryOp_Compliment)
// Logical not: result type fixed to bool.
DEF_UNARY_OP_RETTYPE( ! , UnaryOp_Not, bool)
// Function-style operations: result type equals the operand type.
DEF_UNARY_FUNC_DEFAULT(sin, UnaryOp_Sin)
DEF_UNARY_FUNC_DEFAULT(cos, UnaryOp_Cos)
DEF_UNARY_FUNC_DEFAULT(tan, UnaryOp_Tan)
//-----------------------
// Array
//-----------------------
// Counter used to give every constructed Array a distinct id for printing.
static int uniqueID = 0;

// Fixed-size, heap-allocated array of T that can be assigned from any
// expression object derived from ExprBase.
//
// The class owns its buffer, so it defines the copy constructor and the
// copy assignment operator as well as the destructor. Without them the
// compiler-generated versions would copy the raw pointer and the same
// buffer would be deleted twice.
template <class T>
class Array : public ExprBase< Array<T> >
{
public:
    typedef T Type;

    // Allocates `size` elements; their initial values are whatever
    // `new Type[size]` leaves them as.
    explicit Array(int size)
        : _id(uniqueID++), _size(size), _data(new Type[size])
    {}

    // Deep copy. The copy is a new object and therefore gets its own id.
    Array(const Array<Type>& other)
        : _id(uniqueID++), _size(other._size),
          _data(cloneData(other._data, other._size))
    {}

    ~Array()
    { delete[] _data; }

    // Deep copy assignment. The new buffer is built before the old one is
    // released, so *this is left untouched if allocation or copying throws.
    // The id of *this is kept.
    Array<Type>& operator= (const Array<Type>& other)
    {
        if (this != &other)
        {
            Type* fresh = cloneData(other._data, other._size);
            delete[] _data;
            _data = fresh;
            _size = other._size;
        }
        return *this;
    }

    int size() const
    { return _size; }

    // Unchecked element access.
    const Type& operator() (int i) const
    { return _data[i]; }

    Type& operator() (int i)
    { return _data[i]; }

    // Evaluates the expression element by element into this array.
    // size() elements are read from the expression; no size check is made.
    template <class E>
    Array<Type>& operator= (const ExprBase<E>& e)
    {
        const E& expr = static_cast<const E&>(e);
        for (int i = 0; i < _size; ++i)
            _data[i] = expr(i);
        return *this;
    }

    // Prints "Array<id>" at the given nesting level.
    void treePrint(std::ostream& os, int level = 0) const
    { os << Level(level) << "Array" << _id << std::endl; }

    // Prints the size followed by one line per element.
    friend std::ostream& operator<< (std::ostream& os, const Array<Type>& a)
    {
        os << "size " << a.size() << ":" << std::endl;
        for (int i = 0; i < a.size(); ++i)
            os << "data(" << i << ") = " << a(i) << std::endl;
        return os;
    }

private:
    // Returns a newly allocated copy of src[0..n). The new buffer is freed
    // again if copying an element throws.
    static Type* cloneData(const Type* src, int n)
    {
        Type* fresh = new Type[n];
        try
        {
            for (int i = 0; i < n; ++i)
                fresh[i] = src[i];
        }
        catch (...)
        {
            delete[] fresh;
            throw;
        }
        return fresh;
    }

    int _id;
    int _size;
    Type* _data;
};
//-----------------------
// Main
//-----------------------
// Demo driver: fills two arrays, evaluates a few array expressions and
// prints both the results and the structure of the expression trees.
// cout/endl are qualified with std:: to match the std::ostream used by the
// classes above.
int main(int argc, char* argv[])
{
    Array<double> a(4); // Array0
    Array<double> b(4); // Array1
    Array<double> c(4); // Array2
    Array<double> d(4); // Array3
    Array<double> e(4); // Array4

    a(0) = 0;
    a(1) = 1;
    a(2) = 2;
    a(3) = 3;

    b(0) = 0;
    b(1) = 2;
    b(2) = 4;
    b(3) = 6;

    std::cout << "a: " << a << std::endl;
    std::cout << "b: " << b << std::endl;

    // Simple sum of two arrays.
    std::cout << "Evaluate: c = a + b" << std::endl;
    c = a + b;
    std::cout << "c: " << c << std::endl;
    std::cout << "--" << std::endl;
    std::cout << "Expression: a + b" << std::endl;
    (a + b).treePrint(std::cout);
    std::cout << std::endl;
    std::cout << "--" << std::endl;

    d(0) = 3;
    d(1) = 2;
    d(2) = 1;
    d(3) = 0;

    // Array combined with a scalar, then with another array.
    std::cout << "a: " << a << std::endl;
    std::cout << "c: " << c << std::endl;
    std::cout << "Evaluate: e = (a + 2.0) * c" << std::endl;
    e = (a + 2.0) * c;
    std::cout << "e: " << e << std::endl;
    std::cout << "--" << std::endl;
    std::cout << "Expression: (a + 2.0) * c" << std::endl;
    ((a + 2.0) * c).treePrint(std::cout);
    std::cout << std::endl;
    std::cout << "--" << std::endl;

    // Unary operators and a function applied to a nested expression.
    std::cout << "Expression: sin(-(a + 3) / (1 - b) * c)" << std::endl;
    (sin(-(a + 3) / (1 - b) * c)).treePrint(std::cout);
    std::cout << std::endl;
    std::cout << "--" << std::endl;
    d = (-(a + 3) / (1 - b) * c);
    std::cout << "d: " << d << std::endl;
    e = sin(-(a + 3) / (1 - b) * c);
    std::cout << "e: " << e << std::endl;
    std::cout << "--" << std::endl;

    // funky is defined after main, so it is declared here before use.
    void funky(int, int, const Array<double>&);
    funky(1,2,a);
    funky(2,1,b);
}
// Prints the expression tree built by `(x + 2.0) * y + z`. The part
// `(x + 2.0) * y` involves only built-in scalars, so the only operator
// call that builds a tree node is the final `+ z`.
// cout/endl are qualified with std:: to match the std::ostream used by the
// classes above.
void funky(int x, int y, const Array<double>& z)
{
    std::cout << "Expression: (x + 2.0) * y + z" << std::endl;
    ((x + 2.0) * y + z).treePrint(std::cout);
    std::cout << std::endl;
}
--------------------- blitz-dev list --------------------------------
* To subscribe/unsubscribe: mail to majordomo@oonumerics.org, with
"subscribe blitz-dev" or "unsubscribe blitz-dev" in the body of the message
* Blitz++ web page: http://oonumerics.org/blitz/
This archive was generated by hypermail 2b29 : Wed Feb 20 2002 - 04:30:07 EST