source: coopr.pyomo/trunk/coopr/pyomo/base/expr.py @ 3714

Revision 3714, 25.0 KB checked in by jdsiiro, 3 years ago (diff)

More tweaks of expression generation

  • only clone arguments to intrinsic functions if someone else holds a reference to the argument
  • avoid unnecessary call to as_numeric(_self) when generating operator expressions
  • Property svn:executable set to *
Line 
1#  _________________________________________________________________________
2#
3#  Coopr: A COmmon Optimization Python Repository
4#  Copyright (c) 2008 Sandia Corporation.
5#  This software is distributed under the BSD License.
6#  Under the terms of Contract DE-AC04-94AL85000 with Sandia Corporation,
7#  the U.S. Government retains certain rights in this software.
8#  For more information, see the Coopr README.txt file.
9#  _________________________________________________________________________
10
11from __future__ import division
12
13__all__ = ['Expression', '_LessThanExpression', '_LessThanOrEqualExpression',
14        '_EqualToExpression', '_IdentityExpression' , 'generate_expression']
15
16from plugin import *
17from pyutilib.component.core import *
18from numvalue import *
19from param import _ParamBase
20from var import _VarBase
21
22import sys
23import copy
24import logging
25import StringIO
26
27logger = logging.getLogger('coopr.pyomo')
28
29class Expression(NumericValue):
30    """An object that defines a mathematical expression that can be evaluated"""
31
32    __slots__ = ('_args', )
33
34    def __init__(self, name, nargs, args):
35        """Construct an expression with an operation and a set of arguments"""
36        NumericValue.__init__(self, name=name)
37        self._args=args
38        if nargs and nargs != len(args):
39            raise ValueError, "%s() takes exactly %d arguments (%d given)" % \
40                ( name, nargs, len(args) )
41
42    def __getstate__(self):
43       result = NumericValue.__getstate__(self)
44       for i in Expression.__slots__:
45          result[i] = getattr(self, i)
46       return result
47   
48    def pprint(self, ostream=None, nested=True, eol_flag=True):
49        """Print this expression"""
50        if ostream is None:
51           ostream = sys.stdout
52        if nested:
53           if self.name is None:
54              print >>ostream, "Unnamed-Intrinsic" + "(",
55           else:
56              print >>ostream, self.name + "(",
57           first=True
58           for arg in self._args:
59             if first==False:
60                print >>ostream, ",",
61             if arg.is_expression():
62                arg.pprint(ostream=ostream, nested=nested, eol_flag=False)
63             else:
64                print >>ostream, str(arg),
65             first=False
66           if eol_flag==True:
67              print >>ostream, ")"
68           else:
69              print >>ostream, ")",
70
71    def clone(self):
72        """Clone this object using the specified arguments"""
73        raise NotImplementedError, "Derived expression (%s) failed to "\
74            "implement clone()" % ( str(self.__class__), )
75
76    def simplify(self, model):
77        print """
78WARNING: Expression.simplify() has been deprecated and removed from
79     Pyomo Expressions.  Please remove references to simplify() from your
80     code.
81"""
82        return self
83
84    #
85    # this method contrast with the fixed_value() method.
86    # the fixed_value() method returns true iff the value is
87    # an atomic constant.
88    # this method returns true iff all composite arguments
89    # in this sum expression are constant, i.e., numeric
90    # constants or parametrs. the parameter values can of
91    # course change over time, but at any point in time,
92    # they are constant. hence, the name.
93    #
94    def is_constant(self):
95        for arg in self._args:
96            if not arg.is_constant():
97                return False
98        return True
99
100    def is_expression(self):
101        return True
102
103    def polynomial_degree(self):
104        return None
105
106    def __nonzero__(self):
107        val = self()
108        if val:
109            return True
110        return False
111
112    def __call__(self, exception=True):
113        """Evaluate the expression"""
114        try:
115            return self._apply_operation(self._evaluate_arglist(self._args))
116        except (ValueError, TypeError):
117            if exception:
118                raise
119            return None
120
121    def _evaluate_arglist(self, arglist):
122        for arg in arglist:
123            try:
124                yield value(arg)
125            except Exception, e:
126                buffer = StringIO.StringIO()
127                self.pprint(buffer)
128                logger.error("evaluating expression: %s\nExpression: %s",
129                             str(e), buffer.getvalue())
130                raise
131
132    def _apply_operation(self, values):
133        """Method that can be overwritten to define the operation in
134        this expression"""
135        raise NotImplementedError, "Derived expression (%s) failed to "\
136            "implement _apply_operation()" % ( str(self.__class__), )
137       
138    def __str__(self):
139        return self.name
140
141
class _IntrinsicFunctionExpression(Expression):
    """Expression that applies a stored callable to its evaluated arguments."""

    __slots__ = ('_operator', )

    def __init__(self, name, nargs, args, operator):
        """Store the callable alongside the usual name/arity/arguments."""
        Expression.__init__(self, name, nargs, args)
        self._operator = operator

    def __getstate__(self):
        """Return the Expression state extended with this class's slots."""
        state = Expression.__getstate__(self)
        for slot in _IntrinsicFunctionExpression.__slots__:
            state[slot] = getattr(self, slot)
        return state

    def clone(self):
        """Return a new instance sharing the operator, with a shallow copy
        of the argument list (arity is not re-checked: nargs is None)."""
        args_copy = copy.copy(self._args)
        return self.__class__(self.name, None, args_copy, self._operator)

    def _apply_operation(self, values):
        """Call the stored operator with the evaluated arguments."""
        # Materialize the values first so that an error raised while
        # evaluating an argument propagates unchanged.
        operands = tuple(values)
        return self._operator(*operands)
164
165
166# Should this actually be a special class, or just an instance of
167# _IntrinsicFunctionExpression (like sin, cos, etc)?
class _AbsExpression(_IntrinsicFunctionExpression):
    """Intrinsic-function expression for abs() of exactly one argument."""

    __slots__ = ()

    def __init__(self, args):
        """Bind the 'abs' name, arity 1 and the builtin abs to args."""
        _IntrinsicFunctionExpression.__init__(self, 'abs', 1, args, abs)

    def __getstate__(self):
        """No slots of its own: defer to the parent's state."""
        return _IntrinsicFunctionExpression.__getstate__(self)

    def clone(self):
        """Return a new instance built from a shallow copy of the arguments."""
        args_copy = copy.copy(self._args)
        return self.__class__(args_copy)
180
181# Should this actually be a special class, or just an instance of
182# _IntrinsicFunctionExpression (like sin, cos, etc)?
class _PowExpression(_IntrinsicFunctionExpression):
    """Intrinsic-function expression for pow() of exactly two arguments."""

    __slots__ = ()

    def __init__(self, args):
        """Bind the 'pow' name, arity 2 and the builtin pow to args."""
        _IntrinsicFunctionExpression.__init__(self, 'pow', 2, args, pow)

    def __getstate__(self):
        """No slots of its own: defer to the parent's state."""
        return _IntrinsicFunctionExpression.__getstate__(self)

    def clone(self):
        """Return a new instance built from a shallow copy of the arguments."""
        args_copy = copy.copy(self._args)
        return self.__class__(args_copy)

    def polynomial_degree(self):
        """Always report non-polynomial (None)."""
        # Every _PowExpression is currently treated as non-polynomial.
        # There is a special case -- _args[0] polynomial and _args[1] a
        # constant non-negative integer -- where the result really is
        # polynomial, but most consumers (such as expression compilers)
        # do not recognize that pow could be polynomial, so None is
        # returned regardless.
        return None
205
206
class _LinearExpression(Expression):
    """Expression whose polynomial degree is the largest degree among its
    arguments."""

    __slots__ = ()

    def __init__(self, name, nargs, args):
        """Forward name, arity and arguments to Expression."""
        Expression.__init__(self, name, nargs, args)

    def __getstate__(self):
        """No slots of its own: defer to the parent's state."""
        return Expression.__getstate__(self)

    def clone(self):
        """Return a new instance built from a shallow copy of the arguments.

        The class is constructed with the argument list only, which
        matches the constructors of the relational subclasses.
        """
        args_copy = copy.copy(self._args)
        return self.__class__(args_copy)

    def polynomial_degree(self):
        """Return the highest argument degree, or None if any argument is
        non-polynomial."""
        # max() over all degrees cannot be used directly: a None
        # (non-polynomial) entry must override every numeric degree,
        # so bail out as soon as one is seen.
        highest = 0
        for operand in self._args:
            operand_degree = operand.polynomial_degree()
            if operand_degree is None:
                return None
            if operand_degree > highest:
                highest = operand_degree
        return highest
231
232
class _LessThanExpression(_LinearExpression):
    """Two-argument expression evaluating ``first < second``."""

    __slots__ = ()

    def __init__(self, args):
        """Create the 'lt' expression from exactly two arguments."""
        _LinearExpression.__init__(self, 'lt', 2, args)

    def __getstate__(self):
        """No slots of its own: defer to the parent's state."""
        return _LinearExpression.__getstate__(self)

    def _apply_operation(self, values):
        """Pull the two evaluated operands and compare them with <."""
        lhs = values.next()
        rhs = values.next()
        return lhs < rhs
248
249
class _LessThanOrEqualExpression(_LinearExpression):
    """Two-argument expression evaluating ``first <= second``."""

    __slots__ = ()

    def __init__(self, args):
        """Create the 'le' expression from exactly two arguments."""
        _LinearExpression.__init__(self, 'le', 2, args)

    def __getstate__(self):
        """No slots of its own: defer to the parent's state."""
        return _LinearExpression.__getstate__(self)

    def _apply_operation(self, values):
        """Pull the two evaluated operands and compare them with <=."""
        lhs = values.next()
        rhs = values.next()
        return lhs <= rhs
265
266
class _EqualToExpression(_LinearExpression):
    """Two-argument expression evaluating ``first == second``."""

    __slots__ = ()

    def __init__(self, args):
        """Create the 'eq' expression from exactly two arguments."""
        _LinearExpression.__init__(self, 'eq', 2, args)

    def __getstate__(self):
        """No slots of its own: defer to the parent's state."""
        return _LinearExpression.__getstate__(self)

    def _apply_operation(self, values):
        """Pull the two evaluated operands and compare them with ==."""
        lhs = values.next()
        rhs = values.next()
        return lhs == rhs
282
283
class _ProductExpression(Expression):
    """Expression of the form coef * prod(numerator) / prod(denominator).

    The factors live in the _numerator and _denominator lists; the
    inherited _args slot is set to None and is not used.
    """

    __slots__ = ('_denominator','_numerator','coef')

    def __init__(self):
        """Create an empty product: no factors and a coefficient of 1."""
        Expression.__init__(self, 'prod', None, None)
        self._denominator = []
        self._numerator = []
        self.coef = 1

    def __getstate__(self):
        """Return the Expression state extended with this class's slots."""
        state = Expression.__getstate__(self)
        for slot in _ProductExpression.__slots__:
            state[slot] = getattr(self, slot)
        return state

    def is_constant(self):
        """Return True iff every numerator and denominator factor is
        constant."""
        for factors in (self._numerator, self._denominator):
            for factor in factors:
                if not factor.is_constant():
                    return False
        return True

    def polynomial_degree(self):
        """Return the summed degree of the numerator factors.

        None is returned when any denominator factor has a degree other
        than 0, or when the numerator degrees cannot be added (for
        example because one of them is None).
        """
        for factor in self._denominator:
            if factor.polynomial_degree() != 0:
                return None
        try:
            total = 0
            for factor in self._numerator:
                total = total + factor.polynomial_degree()
        except TypeError:
            return None
        return total

    def invert(self):
        """Turn this product into its reciprocal, in place."""
        self._numerator, self._denominator = \
            self._denominator, self._numerator
        self.coef = 1.0/self.coef

    def pprint(self, ostream=None, nested=True, eol_flag=True):
        """Write 'prod( num=( ... ) , denom=( ... ) )' to ostream.

        Nothing is written unless nested is true.  The coefficient is
        listed first when it differs from 1, a bare 1 is printed for an
        otherwise empty numerator, and the denom=( ... ) group appears
        only when there are denominator factors.
        """
        if ostream is None:
            ostream = sys.stdout
        if not nested:
            return

        print >>ostream, self.name + "( num=(",
        need_comma = False
        if self.coef != 1:
            print >>ostream, str(self.coef),
            need_comma = True
        for factor in self._numerator:
            if need_comma:
                print >>ostream, ",",
            if factor.is_expression():
                factor.pprint(ostream=ostream, nested=nested, eol_flag=False)
            else:
                print >>ostream, str(factor),
            need_comma = True
        if not need_comma:
            # Nothing was listed at all: show the implicit factor of 1.
            print >>ostream, 1,
        print >>ostream, ")",

        if len(self._denominator) > 0:
            print >>ostream, ", denom=(",
            for position, factor in enumerate(self._denominator):
                if position > 0:
                    print >>ostream, ",",
                if factor.is_expression():
                    factor.pprint(ostream=ostream, nested=nested,
                                  eol_flag=False)
                else:
                    print >>ostream, str(factor),
            print >>ostream, ")",

        print >>ostream, ")",
        if eol_flag == True:
            print >>ostream, ""

    def clone(self):
        """Return a new product with shallow copies of both factor lists."""
        duplicate = self.__class__()
        duplicate.name = self.name
        duplicate._numerator = copy.copy(self._numerator)
        duplicate._denominator = copy.copy(self._denominator)
        duplicate.coef = self.coef
        return duplicate

    def __call__(self, exception=True):
        """Evaluate coef * numerator factors / denominator factors.

        A TypeError or ValueError is re-raised when exception is true;
        otherwise None is returned instead.
        """
        try:
            result = self.coef
            for factor_value in self._evaluate_arglist(self._numerator):
                result *= factor_value
            for factor_value in self._evaluate_arglist(self._denominator):
                result /= factor_value
            return result
        except (TypeError, ValueError):
            if not exception:
                return None
            raise
385
386
class _IdentityExpression(_LinearExpression):
    """Single-argument expression that evaluates to its argument's value."""

    __slots__ = ()

    def __init__(self, args):
        """Create the 'identity' expression.

        args may be a list or tuple holding the single argument, or the
        bare argument itself, which is then wrapped in a list.
        """
        container = type(args)
        if container is not list and container is not tuple:
            args = [args]
        Expression.__init__(self, 'identity', 1, args)

    def __getstate__(self):
        """No slots of its own: defer to the parent's state."""
        return _LinearExpression.__getstate__(self)

    def _apply_operation(self, values):
        """Return the first (and only) evaluated argument."""
        return values.next()
406
407
class _SumExpression(_LinearExpression):
    """Weighted sum: _const + sum(_coef[i] * _args[i])."""

    __slots__ = ('_coef','_const')

    def __init__(self):
        """Create an empty sum: no terms and a constant of 0."""
        _LinearExpression.__init__(self, 'sum', None, [])
        self._coef = []
        self._const = 0

    def __getstate__(self):
        """Return the parent state extended with this class's slots."""
        state = _LinearExpression.__getstate__(self)
        for slot in _SumExpression.__slots__:
            state[slot] = getattr(self, slot)
        return state

    def clone(self):
        """Return a new sum with shallow copies of the term and
        coefficient lists."""
        duplicate = self.__class__()
        duplicate.name = self.name
        duplicate._args = copy.copy(self._args)
        duplicate._coef = copy.copy(self._coef)
        duplicate._const = self._const
        return duplicate

    def scale(self, val):
        """Multiply every coefficient and the constant by val, in place."""
        weights = self._coef
        for idx in xrange(len(weights)):
            weights[idx] *= val
        self._const *= val

    def negate(self):
        """Flip the sign of the whole sum, in place."""
        self.scale(-1)

    def pprint(self, ostream=None, nested=True, eol_flag=True):
        """Write 'sum( const , coef * arg , ... )' to ostream.

        Nothing is written unless nested is true.  The constant is listed
        only when it differs from 0, and a coefficient is shown only when
        it differs from 1.
        """
        if ostream is None:
            ostream = sys.stdout
        if not nested:
            return

        print >>ostream, self.name + "(",
        need_comma = False
        if self._const != 0:
            print >>ostream, str(self._const),
            need_comma = True
        for idx, term in enumerate(self._args):
            if need_comma:
                print >>ostream, ",",
            weight = self._coef[idx]
            if weight != 1:
                print >>ostream, str(weight)+" * ",
            if term.is_expression():
                term.pprint(ostream=ostream, nested=nested, eol_flag=False)
            else:
                print >>ostream, str(term),
            need_comma = True
        print >>ostream, ")",
        if eol_flag == True:
            print >>ostream, ""

    def _apply_operation(self, values):
        """Return the constant plus each coefficient times the next
        evaluated argument."""
        total = 0
        for weight in self._coef:
            total = total + weight*values.next()
        return total + self._const
469
470
471
472def generate_expression(*_args):
473    def _clone_if_needed(obj):
474        count = sys.getrefcount(obj) - generate_expression.UNREFERENCED_EXPR_COUNT
475        if generate_expression.clone_if_needed_callback:
476            generate_expression.clone_if_needed_callback(count)
477        if count == 0:
478            return obj
479        elif count > 0:
480            generate_expression.clone_counter += 1
481            return obj.clone()
482        else:
483            raise RuntimeError, "Expression entered generate_expression() " \
484                "with too few references (%s<0); this is indicative of a " \
485                "SERIOUS ERROR in the expression reuse detection scheme." \
486                % ( count, )
487
488    etype = _args[0]
489    _self = _args[1]
490    if _self.is_expression():
491        _self = _clone_if_needed(_self)
492    elif _self.is_indexed():
493        raise ValueError, "Argument for expression '%s' is an n-ary "\
494            "numeric value: %s\n    Have you given variable or "\
495            "parameter '%s' an index?" % (etype, _self.name, _self.name)
496       
497    #
498    # First, handle the special cases of unary opertors
499    #
500    if etype == 'neg':
501        if type(_self) is _SumExpression:
502            _self.negate()
503            return _self
504        else:
505            etype = 'rmul'
506            other = NumericConstant(value=-1)
507    elif etype == 'abs':
508        return _AbsExpression([_self])
509    else:
510        other = as_numeric(_args[2])
511        if other.is_expression():
512            other = _clone_if_needed(other)
513        elif other.is_indexed():
514            raise ValueError, "Argument for expression '%s' is an n-ary "\
515                "numeric value: %s\n    Have you given variable or "\
516                "parameter '%s' an index?" % (etype, _self.name, _self.name)
517
518    #
519    # Binary operators can either be "normal" or "reversed"; reverse all
520    # "reversed" opertors
521    #
522    if etype[0] == 'r':
523        #
524        # This may seem obvious, but if we are performing an
525        # "R"-operation (i.e. reverse operation), then simply reverse
526        # self and other.  This is legitimate as we are generating a
527        # completely new expression here, and the _clone_if_needed logic
528        # above will make sure that we don't accidentally clobber
529        # someone else's expression (fragment).
530        #
531        tmp = _self
532        _self = other
533        other = tmp
534        etype = etype[1:]
535
536    #
537    # Some binary operators are special cases of other operators
538    #
539    multiplier = 1
540    if etype == 'sub':
541        multiplier = -1
542        etype = 'add'
543
544    #
545    # Now, handle all binary operators
546    #
547    if etype == 'add':
548        #
549        # self + other
550        #
551        self_type = type(_self)
552        other_type = type(other)
553        if self_type is _SumExpression:
554            if other_type is NumericConstant:
555                _self._const += multiplier * other()
556            elif other_type is _ProductExpression \
557                     and len(other._numerator) == 1 \
558                     and not other._denominator:
559                _self._args.append(other._numerator[0])
560                _self._coef.append(multiplier*other.coef)
561            elif other_type is _SumExpression:
562                if multiplier < 0:
563                    other.negate()
564                _self._args.extend(other._args)
565                _self._coef.extend(other._coef)
566                _self._const += other._const
567            else:
568                _self._args.append(other)
569                _self._coef.append(multiplier)
570            ans = _self
571        else:
572            if other_type is _SumExpression:
573                if multiplier < 0:
574                    other.negate()
575                if self_type is NumericConstant:
576                    other._const += _self()
577                elif self_type is _ProductExpression and \
578                         len(_self._numerator) == 1 and \
579                         not _self._denominator:
580                    other._args.append(_self._numerator[0])
581                    other._coef.append(_self.coef)
582                else:
583                    other._args.append(_self)
584                    other._coef.append(1)
585                ans = other
586            else:
587                ans = _SumExpression()
588                if self_type is NumericConstant:
589                    ans._const += _self()
590                elif self_type is _ProductExpression and \
591                         len(_self._numerator) == 1 and \
592                         not _self._denominator:
593                    ans._args.append(_self._numerator[0])
594                    ans._coef.append( _self.coef)
595                else:
596                    ans._args.append(_self)
597                    ans._coef.append(1)
598                if other_type is NumericConstant:
599                    ans._const += multiplier * other()
600                elif other_type is _ProductExpression and \
601                         len(other._numerator) == 1 and \
602                         not other._denominator:
603                    ans._args.append(other._numerator[0])
604                    ans._coef.append( multiplier*other.coef)
605                else:
606                    ans._args.append(other)
607                    ans._coef.append( multiplier)
608
609        # Special cases for simplifying expressions
610        if ans._const == 0 and len(ans._args) == 1 and ans._coef[0] == 1:
611            return ans._args[0]
612        return ans
613
614    elif etype == 'mul':
615        #
616        # self * other
617        #
618        self_type = type(_self)
619        other_type = type(other)
620        if self_type is _ProductExpression:
621            if other_type is NumericConstant:
622                _self.coef *= other()
623            elif other_type is _ProductExpression:
624                _self._numerator.extend(other._numerator)
625                _self._denominator.extend(other._denominator)
626                _self.coef *= other.coef
627            else:
628                _self._numerator.append(other)
629            ans = _self
630        else:
631            if other_type is _ProductExpression:
632                if self_type is NumericConstant:
633                    other.coef *= _self()
634                else:
635                    other._numerator.append(_self)
636                ans = other
637            else:
638                ans = _ProductExpression()
639                if self_type is NumericConstant:
640                    ans.coef *= _self()
641                else:
642                    ans._numerator.append(_self)
643                if other_type is NumericConstant:
644                    ans.coef *= other()
645                else:
646                    ans._numerator.append(other)
647
648        # Special cases for simplifying expressions
649        if ans.coef == 0:
650            return NumericConstant(value=0)
651        if ans.coef == 1 and len(ans._numerator) == 1 and not ans._denominator:
652            return ans._numerator[0]
653        return ans
654
655    elif etype == 'div':
656        #
657        # self / other
658        #
659        self_type = type(_self)
660        other_type = type(other)
661        if self_type is _ProductExpression:
662            if other_type is NumericConstant:
663                _self.coef /= other()
664            elif other_type is _ProductExpression:
665                _self._numerator.extend(other._denominator)
666                _self._denominator.extend(other._numerator)
667                _self.coef /= other.coef
668            else:
669                _self._denominator.append(other)
670            ans = _self
671        else:
672            if other_type is _ProductExpression:
673                other.invert()
674                if self_type is NumericConstant:
675                    other.coef *= _self()
676                else:
677                    other._numerator.append(_self)
678                ans = other
679            else:
680                ans = _ProductExpression()
681                if self_type is NumericConstant:
682                    ans.coef *= _self()
683                else:
684                    ans._numerator.append(_self)
685                if other_type is NumericConstant:
686                    ans.coef /= other()
687                else:
688                    ans._denominator.append(other)
689
690        # Special cases for simplifying expressions
691        if ans.coef == 0:
692            return NumericConstant(value=0)
693        if ans.coef == 1 and len(ans._numerator) == 1 and not ans._denominator:
694            return ans._numerator[0]
695        return ans
696
697    elif etype == 'pow':
698        #
699        # self ** other
700        #
701        return _PowExpression((_self, other))
702
703    else:
704        raise RuntimeError, "Unknown expression type '%s'" % etype
705
706##
707## "static" variables within the generate_expression function
708##
709
710# [testing] clone_if_needed_callback allows test functions to intercept
711# the call to _clone_if_needed and see the reference count it computed
712# (sys.getrefcount(obj) minus UNREFERENCED_EXPR_COUNT).
713generate_expression.clone_if_needed_callback = None
714
715# [debugging] clone_counter is a count of the number of calls to
716# expr.clone() made during expression generation.
717generate_expression.clone_counter = 0
718
719# [configuration] UNREFERENCED_EXPR_COUNT is a "magic number" that
720# indicates the stack depth between "normal" modeling and
721# _clone_if_needed().  If an expression enters _clone_if_needed() with
722# UNREFERENCED_EXPR_COUNT references, then there are no other variables
723# that hold a reference to the expression and cloning is not necessary.
724# If there are more references than UNREFERENCED_EXPR_COUNT, then we
725# must clone the expression before operating on it.  It should be an
726# error to hit _clone_if_needed() with fewer than
727# UNREFERENCED_EXPR_COUNT references.
728generate_expression.UNREFERENCED_EXPR_COUNT = 10
729
730
# Register the relational expression classes under both a symbolic and a
# mnemonic key.
# NOTE(review): the trailing True on the '>' / 'gt' / '>=' / 'gte' entries
# presumably asks for the operands to be swapped so that the less-than
# classes can be reused -- TODO confirm against ExpressionRegistration.
731ExpressionRegistration('<', _LessThanExpression)
732ExpressionRegistration('lt', _LessThanExpression)
733ExpressionRegistration('>', _LessThanExpression, True)
734ExpressionRegistration('gt', _LessThanExpression, True)
735ExpressionRegistration('<=', _LessThanOrEqualExpression)
736ExpressionRegistration('lte', _LessThanOrEqualExpression)
737ExpressionRegistration('>=', _LessThanOrEqualExpression, True)
738ExpressionRegistration('gte', _LessThanOrEqualExpression, True)
739ExpressionRegistration('=', _EqualToExpression)
740ExpressionRegistration('eq', _EqualToExpression)
741
# The registrations below are disabled: the `if False:` guard means this
# branch never executes.
742if False:
743    ExpressionRegistration('+', _SumExpression)
744    ExpressionRegistration('sum', _SumExpression)
745    ExpressionRegistration('*', _ProductExpression)
746    ExpressionRegistration('prod', _ProductExpression)
747    #ExpressionRegistration('-', _MinusExpression)
748    #ExpressionRegistration('minus', _MinusExpression)
749    #ExpressionRegistration('/', _DivisionExpression)
750    #ExpressionRegistration('divide', _DivisionExpression)
751    #ExpressionRegistration('-', _NegateExpression)
752    #ExpressionRegistration('negate', _NegateExpression)
753    ExpressionRegistration('abs', _AbsExpression)
754    ExpressionRegistration('pow', _PowExpression)
755
Note: See TracBrowser for help on using the repository browser.