// Union value class -*- c++ -*-

#include "snprintf.h"

#ifdef __GNUC__
# pragma implementation
#endif // __GNUC__
#include "UnionValue.h"
#include "UnionType.h"
#include "Constraint.h"

/** @file UnionValue.C
 * Union value
 */

/* Copyright © 1999-2002 Marko Mäkelä (msmakela@tcs.hut.fi).

   This file is part of MARIA, a reachability analyzer and model checker
   for high-level Petri nets.

   MARIA is free software; you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2, or (at your option)
   any later version.

   MARIA is distributed in the hope that it will be useful, but
   WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   General Public License for more details.

   The GNU General Public License is often shipped with GNU software, and
   is generally kept in a file called COPYING or LICENSE.  If you do not
   have a copy of the license, write to the Free Software Foundation,
   59 Temple Place, Suite 330, Boston, MA 02111 USA. */

/** Construct a union value with a given active component
 * @param type	the union type of the value
 * @param i	index number of the active component
 * @param value	value of the active component; this object takes
 *		ownership of it (the destructor deletes it)
 */
UnionValue::UnionValue (const class Type& type,
			card_t i,
			class Value& value) :
  Value (type),
  myIndex (i),
  myValue (&value)
{
  // Only union types are accepted, and the supplied component value must
  // carry exactly the type that the union declares for component i.
  assert (getType ().getKind () == Type::tUnion);
  assert (&myValue->getType () ==
	  &static_cast<const class UnionType&>(getType ())[myIndex]);
}

/** Copy constructor: duplicates the active component value
 * @param old	the union value to copy
 */
UnionValue::UnionValue (const class UnionValue& old) :
  Value (old.getType ()),
  myIndex (old.myIndex),
  myValue (old.myValue->copy ())
{
  // The duplicated component must exist and match component myIndex.
  assert (getType ().getKind () == Type::tUnion);
  assert (myValue && &myValue->getType () ==
	  &static_cast<const class UnionType&>(getType ())[myIndex]);
}

/** Destructor: releases the owned component value */
UnionValue::~UnionValue ()
{
  // myValue may be NULL here: cast () detaches the component before
  // deleting this object, and deleting a null pointer is harmless.
  delete myValue;
}

/** Less-than comparison of two values of the same union type.
 * Values are ordered first by the index of the active component and,
 * when both have the same active component, by the component values.
 * @param other	the value to compare against
 * @return	true if this value precedes other
 */
bool
UnionValue::operator< (const class UnionValue& other) const
{
  assert (&getType () == &other.getType ());

  if (myIndex != other.myIndex)
    return myIndex < other.myIndex;
  return getValue () < other.getValue ();
}

/** Equality of two values of the same union type
 * @param other	the value to compare against
 * @return	true if both have the same active component and equal
 *		component values
 */
bool
UnionValue::operator== (const class UnionValue& other) const
{
  assert (&getType () == &other.getType ());

  // Values with different active components are never equal;
  // only compare the components when the indices agree.
  if (myIndex != other.myIndex)
    return false;
  return getValue () == other.getValue ();
}

/** Distance between this value and another value of the union type
 * @param other	the value to subtract
 * @return	the difference of the two values in the value numbering
 *		of the union type (computed in card_t arithmetic)
 */
card_t
UnionValue::operator- (const class UnionValue& other) const
{
  const class UnionType& ut =
    static_cast<const class UnionType&>(getType ());

  // Offset between the starting points of the two active components
  // (getCumulatedValues () presumably counts the values of all components
  // preceding the given index -- NOTE(review): confirm in UnionType).
  const card_t offset =
    ut.getCumulatedValues (myIndex) - ut.getCumulatedValues (other.myIndex);

  // Distance contributed by the component values themselves: values of the
  // same component can be subtracted directly, otherwise each is converted
  // to its number within its own component type.
  card_t within;
  if (myIndex != other.myIndex)
    within =
      ut[myIndex].convert (getValue ()) -
      ut[other.myIndex].convert (other.getValue ());
  else
    within = getValue () - other.getValue ();

  return offset + within;
}

/** Reset this value to the first value of its type.
 * With a constraint, the constraint's first value is copied; otherwise
 * the first value of component 0 is taken.  The previously held component
 * value is released.
 */
void
UnionValue::bottom ()
{
  delete myValue;

  const class Constraint* const c = getType ().getConstraint ();
  if (!c) {
    // Unconstrained: the very first value of the first component.  The
    // returned object is adopted (the destructor will delete it).
    myIndex = 0;
    myValue = &static_cast<const class UnionType&>(getType ())
      [myIndex].getFirstValue ();
    return;
  }

  // Constrained: duplicate the constraint's first value, which the
  // constraint keeps for itself.
  const class Value& v = c->getFirstValue ();
  assert (&v.getType () == &getType () && v.getKind () == getKind ());
  const class UnionValue& uv = static_cast<const class UnionValue&>(v);
  myIndex = uv.myIndex;
  myValue = uv.myValue->copy ();
}

/** Reset this value to the last value of its type.
 * With a constraint, the constraint's last value is copied; otherwise
 * the last value of the last component is taken.  The previously held
 * component value is released.
 */
void
UnionValue::top ()
{
  delete myValue;

  const class Constraint* const c = getType ().getConstraint ();
  if (!c) {
    // Unconstrained: the very last value of the last component.  The
    // returned object is adopted (the destructor will delete it).
    const class UnionType& ut =
      static_cast<const class UnionType&>(getType ());
    myIndex = ut.getSize () - 1;
    myValue = &ut[myIndex].getLastValue ();
    return;
  }

  // Constrained: duplicate the constraint's last value, which the
  // constraint keeps for itself.
  const class Value& v = c->getLastValue ();
  assert (&v.getType () == &getType () && v.getKind () == getKind ());
  const class UnionValue& uv = static_cast<const class UnionValue&>(v);
  myIndex = uv.myIndex;
  myValue = uv.myValue->copy ();
}

bool
UnionValue::increment ()
{
  if (const class Constraint* c = getType ().getConstraint ()) {
    const class Value* v = &c->getNextHigh (*this);
    assert (&v->getType () == &getType () && v->getKind () == getKind ());
    if (*this == *static_cast<const class UnionValue*>(v)) {
      if (!(v = c->getNextLow (*this))) {
	bottom ();
	return false;
      }
      assert (&v->getType () == &getType () && v->getKind () == getKind ());
      const class UnionValue& uv = *static_cast<const class UnionValue*>(v);
      delete myValue;
      myIndex = uv.myIndex;
      myValue = uv.myValue->copy ();
      return true;
    }
  }

  if (myValue->increment ())
    return true;

  if (++myIndex == static_cast<const class UnionType&>(getType ()).getSize ())
    myIndex = 0;

  delete myValue;
  myValue = &static_cast<const class UnionType&>(getType ())
    [myIndex].getFirstValue ();

  return myIndex != 0;
}

/** Move this value to its predecessor
 * @return	false if the value wrapped around to top (); true otherwise
 */
bool
UnionValue::decrement ()
{
  if (const class Constraint* c = getType ().getConstraint ()) {
    // Presumably the lower end of the constraint range holding this value
    // (NOTE(review): confirm Constraint::getPrevLow semantics).
    const class Value* v = &c->getPrevLow (*this);
    assert (&v->getType () == &getType () && v->getKind () == getKind ());
    if (*this == *static_cast<const class UnionValue*>(v)) {
      // At the start of a range: jump to the end of the preceding range,
      // or wrap around to the top when no such range exists.
      v = c->getPrevHigh (*this);
      if (!v) {
	top ();
	return false;
      }
      assert (&v->getType () == &getType () && v->getKind () == getKind ());
      const class UnionValue& prev = *static_cast<const class UnionValue*>(v);
      delete myValue;
      myIndex = prev.myIndex;
      myValue = prev.myValue->copy ();
      return true;
    }
  }

  // First try to step back within the active component.
  if (myValue->decrement ())
    return true;

  // The first component wrapped: the whole union wraps to its top.
  if (myIndex == 0) {
    top ();
    return false;
  }

  // Otherwise continue with the last value of the preceding component.
  delete myValue;
  --myIndex;
  myValue = &static_cast<const class UnionType&>(getType ())
    [myIndex].getLastValue ();
  return true;
}

class Value*
UnionValue::cast (const class Type& type)
{
  if (type.getKind () == Type::tUnion)
    return Value::cast (type);
  const class UnionType& ut =
    static_cast<const class UnionType&>(getType ());
  if (!ut[myIndex].isAssignable (type)) {
    delete this;
    return NULL;
  }
  class Value* v = myValue;
  myValue = NULL;
  delete this;
  return v;
}

#include "Printer.h"

void
UnionValue::display (const class Printer& printer) const
{
  printer.print (static_cast<const class UnionType&>(getType ())
		 .getComponentName (myIndex));
  printer.delimiter ('=')++;
  myValue->display (printer);
  printer--;
}

#ifdef EXPR_COMPILE
# include "StringBuffer.h"
# include <stdio.h>

/** Build the C lvalue naming the active member of a union lvalue
 * @param prefix	C name of the union lvalue itself
 * @param i		index number of the active union component
 * @return		a newly allocated string "<prefix>.u.u<i>";
 *			the caller releases it with delete[]
 */
static char*
getLvalueName (const char* prefix, unsigned i)
{
  const size_t prefixLen = strlen (prefix);
  // Room for ".u.u", the decimal digits of i and the terminating NUL
  const size_t suffixRoom = 25;
  char* const buf = new char[prefixLen + suffixRoom];
  memcpy (buf, prefix, prefixLen);
  snprintf (&buf[prefixLen], suffixRoom, ".u.u%u", i);
  return buf;
}

/** Compile this value as a standalone C expression.
 * Not supported for union values: reaching this function is a programming
 * error and trips the assertion.  Code generation for union values in this
 * file goes through compileInit (), compileEqual () and compileOrder ().
 */
void
UnionValue::compile (class StringBuffer&) const
{
  assert (false);
}

/** Generate C code that initializes a union variable to this value
 * @param name		C name of the union variable
 * @param indent	indentation level (0 selects the ", " separator
 *			instead of a ";\n" statement terminator)
 * @param out		the output buffer
 */
void
UnionValue::compileInit (const char* name,
			 unsigned indent,
			 class StringBuffer& out) const
{
  // First set the tag selecting the active component ...
  out.indent (indent);
  out.append (name);
  out.append (".t=");
  out.append (myIndex);
  if (indent)
    out.append (";\n");
  else
    out.append (", ");
  // ... then initialize the corresponding union member.
  char* const member = getLvalueName (name, myIndex);
  myValue->compileInit (member, indent, out);
  delete[] member;
}

/** Generate a C expression testing a union variable for equality
 * (or inequality) with this value
 * @param out		the output buffer
 * @param indent	indentation level
 * @param var		C name of the union variable
 * @param equal		true for an equality test, false for inequality
 * @param first		true if this is the first part of the expression
 *			(suppresses the leading indentation)
 * @param last		true if this is the last part of the expression
 * @return		true (the tag test is always emitted)
 */
bool
UnionValue::compileEqual (class StringBuffer& out,
			  unsigned indent,
			  const char* var,
			  bool equal,
			  bool first,
			  bool last) const
{
  // Equality: "var.t==i&&<member test>"; inequality: "var.t!=i||<...>"
  const char* const relation = equal ? "==" : "!=";
  const char* const connective = equal ? "&&\n" : "||\n";

  if (!first)
    out.indent (indent);
  out.append (var);
  out.append (".t");
  out.append (relation);
  out.append (myIndex);
  out.append (connective);

  char* const member = getLvalueName (var, myIndex);
  const bool emitted =
    myValue->compileEqual (out, indent, member, equal, false, last);
  delete[] member;

  // If the member test produced nothing and this ends the expression,
  // remove the dangling 3-character connective appended above.
  if (last && !emitted)
    out.chop (3);
  return true;
}

/** Generate a C expression comparing the order of a union variable
 * against this value: the tag comparison, or else equal tags combined
 * with the comparison of the active members
 * @param out		the output buffer
 * @param indent	indentation level
 * @param var		C name of the union variable
 * @param less		true to emit "var.t<i", false to emit "var.t>i"
 * @param equal		passed on to the member comparison; may only be
 *			set together with last
 * @param first		true if this is the first part of the expression
 *			(suppresses the leading indentation)
 * @param last		true if this is the last part of the expression
 *			(suppresses the trailing "||")
 * @return		always 0
 */
unsigned
UnionValue::compileOrder (class StringBuffer& out,
			  unsigned indent,
			  const char* var,
			  bool less,
			  bool equal,
			  bool first,
			  bool last) const
{
  assert (!equal || last);

  if (!first)
    out.indent (indent);

  // Tag strictly before/after the active component ...
  out.openParen (1);
  out.append (var);
  if (less)
    out.append (".t<");
  else
    out.append (".t>");
  out.append (myIndex);
  out.append ("||\n");

  // ... or the same tag and an ordered member value.
  out.indent (indent + 1);
  out.openParen (1);
  out.append (var);
  out.append (".t==");
  out.append (myIndex);
  out.append ("&&\n");
  char* const member = getLvalueName (var, myIndex);
  myValue->compileOrder (out, indent, member, less, equal, false, true);
  delete[] member;
  out.closeParen (2);

  if (!last)
    out.append ("||\n");
  return 0;
}

#endif // EXPR_COMPILE
