#! /usr/bin/env python
##
## A python script that reads an XML input file, computes a simple function, 
## and writes out an XML file.
##
## Note: this script can be customized for new applications.  The CUSTOMIZE
## HERE notes indicate the parts of this script that would need to be edited.
##

import sys
import xml.dom.minidom
import re

#
# UTILITY FUNCTION: create strings from arrays of numbers
#
def tostr(array):
    """Format a sequence of values as a space-prefixed string.

    Each value is rendered with ``repr()`` and preceded by a single space,
    so ``[1, 2]`` becomes ``" 1 2"`` and an empty sequence becomes ``""``.

    Args:
        array: iterable of values to format.

    Returns:
        The concatenated string.
    """
    # repr() is the portable spelling of the Python 2 backtick operator,
    # which was removed in Python 3.
    return "".join(" " + repr(val) for val in array)

#
# UTILITY FUNCTION: get text for a node
#
def get_text(node):
    """Return the stripped text held directly by *node*.

    Only the node's immediate text children contribute; text nested inside
    child elements is ignored.
    """
    pieces = (
        kid.nodeValue
        for kid in node.childNodes
        if kid.nodeType == node.TEXT_NODE
    )
    combined = str("" + "".join(pieces))
    return combined.strip()


#
# The type of parameters supported by this application.
#
# This class uses an XML node to initialize the object
# in the constructor.
#
# CUSTOMIZE HERE - To support alternative domain types
#
class MixedIntVars(object):
    """A mixed-integer point initialised from an XML node.

    The constructor scans the element children of the given node and fills
    three lists:

      reals -- values parsed from ``Real`` elements (coerced via ``1.0 * x``)
      ints  -- values parsed from ``Integer`` elements
      bits  -- 0/1 values parsed from ``Binary`` elements
    """

    def __init__(self, node):
        self.reals = []
        self.ints = []
        self.bits = []
        self.process(node)

    def display(self):
        """Print each variable group on its own line, label first."""
        groups = (
            ("Reals", self.reals),
            ("Integers", self.ints),
            ("Binary", self.bits),
        )
        for label, values in groups:
            # print() with one pre-joined string behaves the same under
            # Python 2 (parenthesised expression) and Python 3 (function).
            print(" ".join([label] + [str(val) for val in values]))

    def process(self, node):
        """Append the values found in the element children of *node*.

        Children with empty text are skipped.  ``Real`` and ``Integer``
        text is split on whitespace and each token is evaluated;
        ``Binary`` text is read character by character, keeping only
        '0' and '1'.
        """
        for child in node.childNodes:
            if child.nodeType != node.ELEMENT_NODE:
                continue
            child_text = get_text(child)
            if child_text == "":
                continue
            if child.nodeName == "Real":
                # split() with no argument tokenises on any whitespace, so
                # values wrapped across lines are handled too.
                for val in child_text.split():
                    # NOTE(security): eval() executes arbitrary expressions
                    # taken from the input file.  If the input can come from
                    # an untrusted source, replace this with float(val) or
                    # ast.literal_eval(val).
                    self.reals.append(1.0 * eval(val))
            elif child.nodeName == "Integer":
                for val in child_text.split():
                    # NOTE(security): same eval() caveat as for Real values;
                    # int(val) or ast.literal_eval(val) is the safe choice.
                    self.ints.append(eval(val))
            elif child.nodeName == "Binary":
                for val in child_text:
                    # Any character other than '0' or '1' is ignored.
                    if val == '1':
                        self.bits.append(1)
                    elif val == '0':
                        self.bits.append(0)

#
# A table of response types supported by this application
#
# CUSTOMIZE HERE - if other types of responses are needed
#
# Response element names this application knows how to compute.
xml_response_names = [
    "FunctionValue",
    "Gradient",
    "ConstraintValues",
]

#
# The test function
#
# CUSTOMIZE HERE
#
def compute_response(point, requests):
    """Evaluate the test function for every supported request.

    Args:
        point: object exposing ``reals``, ``ints`` and ``bits`` sequences.
        requests: dict keyed by response name; only the keys are read.

    Returns:
        A dict mapping each requested name listed in ``xml_response_names``
        to its value formatted as a string.  Requested names that are not
        supported are left out of the result.
    """
    response = {}
    for key in requests:
        if key not in xml_response_names:
            continue
        if key == "FunctionValue":
            # Weighted sum over all variables.  The terms are accumulated:
            # plain assignment would keep only the last term, which would
            # not match the Gradient reported below.
            val = 0.0
            for i, x in enumerate(point.reals):
                val += (i + 1) * x
            for i, x in enumerate(point.ints):
                val += 1000 * (i + 1) * x
            for i, x in enumerate(point.bits):
                val += 1000000 * (i + 1) * x
            response[key] = str(val)
        elif key == "Gradient":
            # One entry per real variable: its weight (i + 1).
            response[key] = tostr(range(1, len(point.reals) + 1))
        elif key == "ConstraintValues":
            # One constraint per variable.  Each group squares its own
            # values; indexing point.reals from the ints/bits loops would
            # give wrong results and fail with IndexError whenever there
            # are more ints or bits than reals.
            val = [(i + 1) + x * x for i, x in enumerate(point.reals)]
            val += [(i + 1) + 1000 * x * x for i, x in enumerate(point.ints)]
            val += [(i + 1) + 1000000 * x * x for i, x in enumerate(point.bits)]
            response[key] = tostr(val)
    return response


#
# Process the document
#
def process(point, requests):
    """Build the ``ColinResponse`` XML document for *point*.

    One child element is emitted per requested name, in the iteration
    order of *requests*.  Its text is the computed response, or an error
    message when the request is not supported.
    """
    computed = compute_response(point, requests)

    document = xml.dom.minidom.Document()
    top = document.createElement("ColinResponse")
    document.appendChild(top)

    for name in requests:
        if name in computed:
            content = computed[name]
        else:
            content = "ERROR: Unsupported application request"
        element = document.createElement(name)
        top.appendChild(element)
        element.appendChild(document.createTextNode(content))
    return document
       
#
# A function that processes the requests
#
def handleRequests(node):
    """Collect the requests listed under *node*.

    Args:
        node: DOM node whose element children each name one request.

    Returns:
        A dict mapping each child element's name to a dict of that
        element's XML attributes.  If a name occurs more than once, the
        last occurrence wins.
    """
    requests = {}
    for child in node.childNodes:
        if child.nodeType == node.ELEMENT_NODE:
            requests[str(child.nodeName)] = dict(child.attributes.items())
    return requests

##
## MAIN
##
if __name__ == "__main__":
    # The usage text advertises a <log> argument, but only the input and
    # output paths are read below, so two arguments are enough.
    if len(sys.argv) < 3:
        print("shell_func <input> <output> <log>")
        sys.exit(1)
    #
    # Parse XML input file
    #
    input_doc = xml.dom.minidom.parse(sys.argv[1])
    point = MixedIntVars(input_doc.getElementsByTagName("Domain")[0])
    requests = handleRequests(input_doc.getElementsByTagName("Requests")[0])
    #
    # Create output XML object
    #
    output_doc = process(point, requests)
    # The with-block closes the output file even if writexml() raises.
    with open(sys.argv[2], "w") as OUTPUT:
        output_doc.writexml(OUTPUT, " ", " ", "\n", "UTF-8")
