#!/usr/bin/env python
#
# odb.py
#
# Object Database Api
#
# Written by David Jeske <jeske@neotonic.com>, 2001/07. 
# Inspired by eGroups' sqldb.py originally written by Scott Hassan circa 1998.
#
# Copyright (C) 2001, by David Jeske and Neotonic
#
# Goals:
#       - a simple object-like interface to database data
#       - database independent (someday)
#       - relational-style "rigid schema definition"
#       - object style easy-access
#
# Example:
#
#  import odb
#  import MySQLdb
#
#  # define table
#  class AgentsTable(odb.Table):
#    def _defineRows(self):
#      self.d_addColumn("agent_id",odb.kInteger,None,primarykey = 1,autoincrement = 1)
#      self.d_addColumn("login",odb.kVarString,200,notnull=1)
#      self.d_addColumn("ticket_count",odb.kIncInteger,None)
#
#  if __name__ == "__main__":
#    # open database
#    ndb = MySQLdb.connect(host = 'localhost',
#                          user='username', 
#                          passwd = 'password', 
#                          db='testdb')
#    db = odb.Database(ndb)
#    tbl = AgentsTable(db,"agents")
#
#    # create row
#    agent_row = tbl.newRow()
#    agent_row.login = "foo"
#    agent_row.save()
#
#    # fetch row (the match spec must identify a unique row: the full primary key or a unique column)
#    try:
#      get_row = tbl.fetchRow( ('agent_id', agent_row.agent_id) )
#    except odb.eNoMatchingRows:
#      print "this is bad, we should have found the row"
#
#    # fetch rows (can return empty list)
#    list_rows = tbl.fetchRows( ('login', "foo") )
#

import string
import sys, zlib
from log import *

import handle_error

##########################################
# Exceptions
#
# All odb errors derive from eOdbError.  They are classes, not bare strings:
# Python 2.6+ refuses to raise string exceptions (TypeError), while the
# existing "raise eX, msg" and "except odb.eX" forms work unchanged with
# classes.

class eOdbError(Exception):
    """Base class for the errors raised by odb."""

class eNoSuchColumn(eOdbError):
    """A column name is not defined on the table."""

class eNonUniqueMatchSpec(eOdbError):
    """A match spec does not identify a unique row where one is required."""

class eNoMatchingRows(eOdbError):
    """No row matched a lookup that expected one (e.g. tbl.fetchRow)."""

class eInternalError(eOdbError):
    """Inconsistent table definition detected internally (e.g. two
    autoincrement columns)."""

class eInvalidMatchSpec(eOdbError):
    """A match spec is not a (column, value) tuple or a list of them."""

class eInvalidData(eOdbError):
    """A value or a column-definition option is invalid."""

class eUnsavedObjectLost(eOdbError):
    """NOTE(review): meaning inferred from the name only (unsaved row data
    discarded) -- confirm at its raise sites."""

class eDuplicateKey(eOdbError):
    """The database rejected a write with a "Duplicate entry" error."""

#####################################
# COLUMN TYPES
#
# typename                            # meaning of the size argument / SQL type
kInteger       = "kInteger"          # -        -> integer
kFixedString   = "kFixedString"      # size     -> char(size)
kVarString     = "kVarString"        # maxsize  -> varchar(size)
kBigString     = "kBigString"        # -        -> text
kIncInteger    = "kIncInteger"       # -        -> integer, updated as col = col + n
kDateTime      = "kDateTime"         # -        -> datetime
kTimeStamp     = "kTimeStamp"        # -        -> timestamp
kReal          = "kReal"             # -        -> real


DEBUG = 0

##############
# Database
#
# this will ultimately turn into a mostly abstract base class for
# the DB adaptors for different database types....
#

class Database:
    """Owns a DB-API connection and the Table objects registered on it.

    This is meant to become a mostly abstract base class for per-database
    adaptors: escape(), listTables(), listIndices(), listFieldsDict() and
    alterTableToMatch() raise NotImplementedError here and are meant to be
    overridden.

    Tables registered with addTable(attrname, ...) are reachable as
    attributes, i.e. db.<attrname>.
    """

    def __init__(self, db, debug=0):
        # db: the underlying connection; it must provide cursor() and close()
        # (the example at the top of this file passes a MySQLdb connection).
        self._tables = {}          # attrname -> Table; must exist before any __getattr__ lookup
        self.db = db
        self._cursor = None        # created lazily by defaultCursor()
        self.compression_enabled = 0
        self.debug = debug
        # NOTE(review): left as None here; presumably adaptors set it to
        # their driver's exception class -- confirm.
        self.SQLError = None

        self.__defaultRowClass = self.defaultRowClass()
        self.__defaultRowListClass = self.defaultRowListClass()

    def defaultCursor(self):
        """Return the shared cursor, creating it on first use."""
        if self._cursor is None:
            self._cursor = self.db.cursor()
        return self._cursor

    def escape(self, str):
        """Escape a value for use inside a single-quoted SQL literal (adaptor hook)."""
        raise NotImplementedError("Unimplemented Error")

    def getDefaultRowClass(self): return self.__defaultRowClass
    def setDefaultRowClass(self, clss): self.__defaultRowClass = clss
    def getDefaultRowListClass(self): return self.__defaultRowListClass
    def setDefaultRowListClass(self, clss): self.__defaultRowListClass = clss

    def defaultRowClass(self):
        """Row class for tables that do not pass one; read once by __init__."""
        return Row

    def defaultRowListClass(self):
        """Container class for fetch results; read once by __init__."""
        # base type is list...
        return list

    def addTable(self, attrname, tblname, tblclass,
                 rowClass = None, check = 0, create = 0, rowListClass = None):
        """Instantiate tblclass for SQL table tblname and expose it as self.<attrname>."""
        tbl = tblclass(self, tblname, rowClass=rowClass, check=check,
                       create=create, rowListClass=rowListClass)
        self._tables[attrname] = tbl
        return tbl

    def close(self):
        """Detach all registered tables and close the underlying connection."""
        for tbl in self._tables.values():
            tbl.db = None
        self._tables = {}
        if self.db is not None:
            self.db.close()
            self.db = None

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails: expose tables.
        if key == "_tables":
            # __init__ never ran; stop here instead of recursing forever.
            raise AttributeError("odb.Database: not initialized properly, self._tables does not exist")

        try:
            return self._tables[key]
        except KeyError:
            raise AttributeError("odb.Database: unknown attribute %s" % (key))

    def _runTransactionCommand(self, command, cursor):
        # Shared body of begin/commit/rollbackTransaction.
        if cursor is None:
            cursor = self.defaultCursor()
        dlog(DEV_UPDATE, command)
        cursor.execute(command)

    def beginTransaction(self, cursor=None):
        """Execute "begin" on cursor (default: the shared cursor)."""
        self._runTransactionCommand("begin", cursor)

    def commitTransaction(self, cursor=None):
        """Execute "commit" on cursor (default: the shared cursor)."""
        self._runTransactionCommand("commit", cursor)

    def rollbackTransaction(self, cursor=None):
        """Execute "rollback" on cursor (default: the shared cursor)."""
        self._runTransactionCommand("rollback", cursor)

    ##
    ## schema creation code
    ##

    def createTables(self):
        """Create each registered table missing from the database.

        Tables that already exist are passed to checkTable(), which prints
        warnings about columns that differ between the app and the database.
        """
        tables = self.listTables()

        for tbl in self._tables.values():
            tblname = tbl.getTableName()

            if tblname not in tables:
                print("table %s does not exist" % tblname)
                tbl.createTable()
            else:
                tbl.checkTable()

    def createIndices(self):
        """Create every index registered on the tables that listIndices() does not report."""
        indices = self.listIndices()

        for tbl in self._tables.values():
            for indexName, (columns, unique) in tbl.getIndices().items():
                if indexName in indices:
                    continue
                tbl.createIndex(columns, indexName=indexName, unique=unique)

    def synchronizeSchema(self):
        """Call alterTableToMatch() for every registered table."""
        for tbl in self._tables.values():
            self.alterTableToMatch(tbl)

    def alterTableToMatch(self, tbl):
        """Adaptor hook used by synchronizeSchema()."""
        raise NotImplementedError("Unimplemented Error")

    def listTables(self, cursor=None):
        """Return the names of the tables in the database (adaptor hook)."""
        raise NotImplementedError("Unimplemented Error")

    def listIndices(self, cursor=None):
        """Return the names of the existing indices (adaptor hook used by createIndices())."""
        raise NotImplementedError("Unimplemented Error")

    def listFieldsDict(self, table_name, cursor=None):
        """Return {column_name: description} for table_name (adaptor hook;
        the description format is adaptor-specific)."""
        raise NotImplementedError("Unimplemented Error")

    def listFields(self, table_name, cursor=None):
        """Return the column names of table_name."""
        columns = self.listFieldsDict(table_name, cursor=cursor)
        return columns.keys()

##########################################
# Table
#


class Table:
    def subclassinit(self):
        """Hook for subclasses, called by __init__ after the columns are locked.

        It runs before the optional createTable()/checkTable() calls.  The
        base implementation does nothing.
        """
    def __init__(self,database,table_name,
                 rowClass = None, check = 0, create = 0, rowListClass = None):
        """Bind this table definition to a database and build its schema.

        database     -- owning Database; supplies the default row and
                        row-list classes and is used by create/check
        table_name   -- SQL table name
        rowClass     -- class for fetched rows (a false value means
                        database.getDefaultRowClass())
        check        -- if true, run checkTable() after setup
        create       -- if true, run createTable() after setup
        rowListClass -- container for fetch results (a false value means
                        database.getDefaultRowListClass())
        """
        self.db = database
        self.__table_name = table_name

        # a false argument selects the database-wide default
        self.__defaultRowClass = rowClass or database.getDefaultRowClass()
        self.__defaultRowListClass = (rowListClass or
                                      database.getDefaultRowListClass())

        # definition state, filled in by the d_* calls inside _defineRows()
        self.__column_list = []
        self.__vcolumn_list = []
        self.__columns_locked = 0
        self.__has_value_column = 0
        self.__indices = {}
        self.__relations_by_table = {}

        # lookup tables, built by __lockColumnsAndInit()
        self.__col_def_hash = None
        self.__vcol_def_hash = None
        self.__primary_key_list = None

        self._defineRows()            # subclass declares its columns
        self.__lockColumnsAndInit()   # freeze the definition, build lookups
        self.subclassinit()

        if create:
            self.createTable()
        if check:
            self.checkTable()

    def _colTypeToSQLType(self, colname, coltype, options):
        """Return the SQL column-definition fragment for one column.

        colname -- column name
        coltype -- a k* column type; unknown values are emitted verbatim as
                   the SQL type
        options -- option dict built by d_addColumn()

        PRIMARY KEY is not emitted here; _createTableSQL() adds it at the
        table level.
        """
        simple_types = {
            kInteger:    "integer",
            kBigString:  "text",
            kIncInteger: "integer",
            kDateTime:   "datetime",
            kTimeStamp:  "timestamp",
            kReal:       "real",
        }
        sized_types = {kFixedString: "char", kVarString: "varchar"}

        if coltype in sized_types:
            base = sized_types[coltype]
            size = options.get('size', None)
            if size is None:
                sqltype = base
            else:
                sqltype = "%s(%s)" % (base, size)
        else:
            sqltype = simple_types.get(coltype, coltype)

        parts = ["%s %s" % (colname, sqltype)]
        if options.get('notnull', 0):
            parts.append("NOT NULL")
        if options.get('autoincrement', 0):
            parts.append("AUTO_INCREMENT")
        if options.get('unique', 0):
            parts.append("UNIQUE")
        default = options.get('default', None)
        if default is not None:
            parts.append("DEFAULT %s" % default)

        return " ".join(parts)

    def getTableName(self):
        """Return the SQL name of this table."""
        name = self.__table_name
        return name
    def setTableName(self, tablename):
        """Record a new SQL table name; no SQL is issued."""
        self.__table_name = tablename

    def getIndices(self):
        """Return the registered indices as {indexName: (columns, unique)} (not a copy)."""
        indices = self.__indices
        return indices

    def _createTableSQL(self):
        """Build the CREATE TABLE statement for the application columns.

        Primary-key columns are collected into a trailing table-level
        ", PRIMARY KEY (...)" clause.
        """
        column_sql = string.join(
            [self._colTypeToSQLType(name, ctype, opts)
             for name, ctype, opts in self.__column_list],
            ", ")

        pk_clause = ""
        pk_cols = self.getPrimaryKeyList()
        if pk_cols:
            pk_clause = ", PRIMARY KEY (" + string.join(pk_cols, ",") + ")"

        return "create table %s (%s %s)" % (self.__table_name, column_sql, pk_clause)

    def createTable(self, cursor=None):
      if cursor is None: cursor = self.db.defaultCursor()
      sql = self._createTableSQL()
      print "CREATING TABLE:", sql
      cursor.execute(sql)

    def dropTable(self, cursor=None):
        """Drop this table, ignoring the database driver's SQL errors.

        Errors are ignored only when the owning Database's SQLError is set
        to an exception class; while it is None, errors propagate.

        cursor -- cursor to use; defaults to the database's shared cursor
        """
        if cursor is None: cursor = self.db.defaultCursor()
        # SQLError is an attribute of the Database, not the Table; looking it
        # up on self raised AttributeError whenever the DROP failed.  An
        # empty tuple in an except clause matches nothing.
        sql_error = self.db.SQLError or ()
        try:
            cursor.execute("drop table %s" % self.__table_name)   # clean out the table
        except sql_error:
            pass

    def renameTable(self, newTableName, cursor=None):
        """Rename the table in the database and record the new name.

        SQL errors of the Database's SQLError class are ignored, and the new
        name is recorded either way.  While SQLError is None, errors propagate.

        newTableName -- new SQL table name
        cursor       -- cursor to use; defaults to the database's shared cursor
        """
        if cursor is None: cursor = self.db.defaultCursor()
        # SQLError lives on the Database ("sel.SQLError" was a NameError);
        # an empty tuple in an except clause matches nothing.
        sql_error = self.db.SQLError or ()
        try:
            cursor.execute("rename table %s to %s" % (self.__table_name, newTableName))
        except sql_error:
            pass

        self.setTableName(newTableName)
      
    def getTableColumnsFromDB(self):
        """Return the live column dict for this table via Database.listFieldsDict()."""
        table_name = self.__table_name
        return self.db.listFieldsDict(table_name)
      
    def checkTable(self, warnflag=1):
      invalidDBCols = {}
      invalidAppCols = {}

      dbcolumns = self.getTableColumnsFromDB()
      for coldef in self.__column_list:
        colname = coldef[0]

        dbcoldef = dbcolumns.get(colname, None)
        if dbcoldef is None:
          invalidAppCols[colname] = 1
      
      for colname, row in dbcolumns.items():
        coldef = self.__col_def_hash.get(colname, None)
        if coldef is None:
          invalidDBCols[colname] = 1

      if warnflag == 1:
        if invalidDBCols:
          print "----- WARNING ------------------------------------------"
          print "  There are columns defined in the database schema that do"
          print "  not match the application's schema."
          print "  columns:", invalidDBCols.keys()
          print "--------------------------------------------------------"

        if invalidAppCols: 
          print "----- WARNING ------------------------------------------"
          print "  There are new columns defined in the application schema"
          print "  that do not match the database's schema."
          print "  columns:", invalidAppCols.keys()
          print "--------------------------------------------------------"

      return invalidAppCols, invalidDBCols


    def alterTableToMatch(self):
      raise "Unimplemented Error!"

    def addIndex(self, columns, indexName=None, unique=0):
        """Record an index definition in getIndices(); no SQL is issued here.

        columns   -- sequence of column names
        indexName -- defaults to "<table>_index_<col1>_<col2>..."
        unique    -- true for a UNIQUE index
        """
        name = indexName
        if name is None:
            name = "%s_index_%s" % (self.getTableName(), string.join(columns, "_"))
        self.__indices[name] = (columns, unique)
      
    def createIndex(self, columns, indexName=None, unique=0, cursor=None):
        """Execute CREATE [UNIQUE] INDEX on this table, reporting the SQL via warn().

        columns   -- sequence of column names
        indexName -- defaults to "<table>_index_<col1>_<col2>..."
        unique    -- true for a UNIQUE index
        cursor    -- cursor to use; defaults to the database's shared cursor
        """
        if cursor is None: cursor = self.db.defaultCursor()
        column_sql = string.join(columns, ",")
        table = self.getTableName()

        name = indexName
        if name is None:
            name = table + "_index_" + string.join(columns, "_")

        if unique:
            uniquesql = " unique"
        else:
            uniquesql = ""

        sql = "create %s index %s on %s (%s)" % (uniquesql, name, table, column_sql)
        warn("creating index", sql)
        cursor.execute(sql)


    ## Column Definition

    def getColumnDef(self,column_name):
        try:
            return self.__col_def_hash[column_name]
        except KeyError:
            try:
                return self.__vcol_def_hash[column_name]
            except KeyError:
                raise eNoSuchColumn, "no column (%s) on table %s" % (column_name,self.__table_name)

    def getColumnList(self):
        """Return a new list: application columns followed by value columns."""
        combined = []
        combined.extend(self.__column_list)
        combined.extend(self.__vcolumn_list)
        return combined
    def getAppColumnList(self):
        """Return the internal list of application column definitions (not a copy)."""
        app_columns = self.__column_list
        return app_columns

    def databaseSizeForData_ColumnName_(self,data,col_name):
        """Estimate how many bytes `data` would occupy in column col_name.

        - kBigString: len(data), or the zlib level-9 compressed length when
          the column is compress_ok, compression is enabled on the Database,
          and compressing actually shrinks the data.
        - any other column: len(data) when data supports data[0], otherwise
          4 (numbers, None, empty sequences, ...).

        Raises eNoSuchColumn for an unknown column.
        """
        c_name,c_type,c_options = self.getColumnDef(col_name)

        if c_type == kBigString:
            if c_options.get("compress_ok",0) and self.db.compression_enabled:
                z_size = len(zlib.compress(data,9))
                return min(z_size, len(data))
            return len(data)

        # really simplistic database size computation.  Only the "can't
        # index / can't measure" failures are caught; a bare except would
        # also swallow KeyboardInterrupt and MemoryError.
        try:
            data[0]
            return len(data)
        except (TypeError, IndexError, KeyError, AttributeError):
            return 4
            

    def columnType(self, col_name):
        """Return the k* column type of col_name (application or value column).

        Raises eNoSuchColumn for an unknown column.
        """
        name, ctype, options = self.getColumnDef(col_name)
        return ctype

    def convertDataForColumn(self,data,col_name):
        try:
            col_def = self.__col_def_hash[col_name]
        except KeyError:
            try:
                col_def = self.__vcol_def_hash[col_name]
            except KeyError:
                raise eNoSuchColumn, "no column (%s) on table %s" % (col_name,self.__table_name)

        c_name,c_type,c_options = col_def

        if c_type == kIncInteger:
            raise eInvalidData, "invalid operation for column (%s:%s) on table (%s)" % (col_name,c_type,self.__table_name)

        if c_type == kInteger:
            try:
                if data is None: data = 0
                else: return long(data)
            except (ValueError,TypeError):
                raise eInvalidData, "invalid data (%s) for col (%s:%s) on table (%s)" % (repr(data),col_name,c_type,self.__table_name)
        elif c_type == kReal:
            try:
                if data is None: data = 0.0
                else: return float(data)
            except (ValueError,TypeError):
                raise eInvalidData, "invalid data (%s) for col (%s:%s) on table (%s)" % (repr(data), col_name,c_type,self.__table_name)

        else:
            if type(data) == type(long(0)):
                return "%d" % data
            else:
                return str(data)

    def getPrimaryKeyList(self):
        """Return the primary-key column names in definition order.

        The value is None until __init__ has locked the column definitions.
        """
        pk_list = self.__primary_key_list
        return pk_list
    
    def hasValueColumn(self):
        """Return true if d_addValueColumn() was called for this table."""
        flag = self.__has_value_column
        return flag

    def hasColumn(self,name):
        """Return whether `name` is an application (non-value) column."""
        return name in self.__col_def_hash
    def hasVColumn(self,name):
        """Return whether `name` is a value column declared via d_addVColumn()."""
        return name in self.__vcol_def_hash
        

    def _defineRows(self):
        raise "can't instantiate base odb.Table type, make a subclass and override _defineRows()"

    def __lockColumnsAndInit(self):
        # add a 'odb_value column' before we lockdown the table def
        if self.__has_value_column:
            self.d_addColumn("odb_value",kBigText,default='')

        self.__columns_locked = 1
        # walk column list and make lookup hashes, primary_key_list, etc..

        primary_key_list = []
        col_def_hash = {}
        for a_col in self.__column_list:
            name,type,options = a_col
            col_def_hash[name] = a_col
            if options.has_key('primarykey'):
                primary_key_list.append(name)

        self.__col_def_hash = col_def_hash
        self.__primary_key_list = primary_key_list

        # setup the value columns!

        if (not self.__has_value_column) and (len(self.__vcolumn_list) > 0):
            raise "can't define vcolumns on table without ValueColumn, call d_addValueColumn() in your _defineRows()"

        vcol_def_hash = {}
        for a_col in self.__vcolumn_list:
            name,type,size_data,options = a_col
            vcol_def_hash[name] = a_col

        self.__vcol_def_hash = vcol_def_hash
        
        
    def __checkColumnLock(self):
        if self.__columns_locked:
            raise "can't change column definitions outside of subclass' _defineRows() method!"

    # table definition methods, these are only available while inside the
    # subclass's _defineRows method
    #
    # Ex:
    #
    # import odb
    # class MyTable(odb.Table):
    #   def _defineRows(self):
    #     self.d_addColumn("id",odb.kInteger,primarykey = 1,autoincrement = 1)
    #     self.d_addColumn("name",odb.kVarString,120)
    #     self.d_addColumn("type",odb.kInteger,
    #                      enum_values = { 0 : "alive", 1 : "dead" })

    def d_addColumn(self,col_name,ctype,size=None,primarykey = 0, 
                    notnull = 0,indexed=0,
                    default=None,unique=0,autoincrement=0,safeupdate=0,
                    enum_values = None,
		    no_export = 0,
                    relations=None,compress_ok=0,int_date=0):

        self.__checkColumnLock()

        options = {}
        options['default']       = default
        if primarykey:
            options['primarykey']    = primarykey
        if unique:
            options['unique']        = unique
        if indexed:
            options['indexed']       = indexed
            self.addIndex((col_name,))
        if safeupdate:
            options['safeupdate']    = safeupdate
        if autoincrement:
            options['autoincrement'] = autoincrement
        if notnull:
            options['notnull']       = notnull
        if size:
            options['size']          = size
        if no_export:
            options['no_export']     = no_export
        if int_date:
            if ctype != kInteger:
                raise eInvalidData, "can't flag columns int_date unless they are kInteger"
            else:
                options['int_date'] = int_date
            
        if enum_values:
            options['enum_values']   = enum_values
            inv_enum_values = {}
            for k,v in enum_values.items():
                if inv_enum_values.has_key(v):
                    raise eInvalidData, "enum_values paramater must be a 1 to 1 mapping for Table(%s)" % self.__table_name
                else:
                    inv_enum_values[v] = k
            options['inv_enum_values'] = inv_enum_values
        if relations:
            options['relations']      = relations
            for a_relation in relations:
                table, foreign_column_name = a_relation
                if self.__relations_by_table.has_key(table):
                    raise eInvalidData, "multiple relations for the same foreign table are not yet supported" 
                self.__relations_by_table[table] = (col_name,foreign_column_name)
        if compress_ok:
            if ctype == kBigString:
                options['compress_ok'] = 1
            else:
                raise eInvalidData, "only kBigString fields can be compress_ok=1"
        
        self.__column_list.append( (col_name,ctype,options) )

    def d_addValueColumn(self):
        """Enable the implicit "odb_value" column; required before d_addVColumn().

        Valid only inside _defineRows().
        """
        self.__checkColumnLock()
        self.__has_value_column = 1

    def d_addVColumn(self,col_name,type,size=None,default=None):
        self.__checkColumnLock()

        if (not self.__has_value_column):
            raise "can't define VColumns on table without ValueColumn, call d_addValueColumn() first"

        options = {}
        if default:
            options['default'] = default
        if size:
            options['size']    = size

        self.__vcolumn_list.append( (col_name,type,options) )

    #####################
    # _fixColMatchSpec(col_match_spec,should_match_unique_row = 0)
    #
    # normalize col_match_spec (a (column, value) tuple, a list of such
    # tuples, or None) into a list of (column, value) tuples, raising an
    # error if it contains invalid columns, or (in the case of
    # should_match_unique_row) if it does not fully specify a unique row.
    #
    # NOTE: we don't currently support where clauses with value column fields!
    #
    
    def _fixColMatchSpec(self,col_match_spec, should_match_unique_row = 0):
        if type(col_match_spec) == type([]):
            if type(col_match_spec[0]) != type((0,)):
                raise eInvalidMatchSpec, "invalid types in match spec, use [(,)..] or (,)"
        elif type(col_match_spec) == type((0,)):
            col_match_spec = [ col_match_spec ]
        elif type(col_match_spec) == type(None):
            if should_match_unique_row:
                raise eNonUniqueMatchSpec, "can't use a non-unique match spec (%s) here" % col_match_spec
            else:
                return None
        else:
            raise eInvalidMatchSpec, "invalid types in match spec, use [(,)..] or (,)"

        if should_match_unique_row:
            unique_column_lists = []

            # first the primary key list
            my_primary_key_list = []
            for a_key in self.__primary_key_list:
                my_primary_key_list.append(a_key)

            # then other unique keys
            for a_col in self.__column_list:
                col_name,a_type,options = a_col
                if options.has_key('unique'):
                    unique_column_lists.append( (col_name, [col_name]) )

            unique_column_lists.append( ('primary_key', my_primary_key_list) )
                
        
        new_col_match_spec = []
        for a_col in col_match_spec:
            name,val = a_col
            # newname = string.lower(name)
            #  what is this doing?? - jeske
            newname = name
            if not self.__col_def_hash.has_key(newname):
                raise eNoSuchColumn, "no such column in match spec: '%s'" % newname

            new_col_match_spec.append( (newname,val) )

            if should_match_unique_row:
                for name,a_list in unique_column_lists:
                    try:
                        a_list.remove(newname)
                    except ValueError:
                        # it's okay if they specify too many columns!
                        pass

        if should_match_unique_row:
            for name,a_list in unique_column_lists:
                if len(a_list) == 0:
                    # we matched at least one unique colum spec!
                    # log("using unique column (%s) for query %s" % (name,col_match_spec))
                    return new_col_match_spec
            
            raise eNonUniqueMatchSpec, "can't use a non-unique match spec (%s) here" % col_match_spec

        return new_col_match_spec

    def __buildWhereClause (self, col_match_spec,other_clauses = None):
        sql_where_list = []

        if not col_match_spec is None:
            for m_col in col_match_spec:
                m_col_name,m_col_val = m_col
                c_name,c_type,c_options = self.__col_def_hash[m_col_name]
                if c_type in (kIncInteger, kInteger):
                    try:
                        m_col_val_long = long(m_col_val)
                    except ValueError:
                        raise ValueError, "invalid literal for long(%s) in table %s" % (repr(m_col_val),self.__table_name)
                        
                    sql_where_list.append("%s = %d" % (c_name, m_col_val_long))
                elif c_type == kReal:
                    try:
                        m_col_val_float = float(m_col_val)
                    except ValueError:
                        raise ValueError, "invalid literal for float(%s) is table %s" % (repr(m_col_val), self.__table_name)
                    sql_where_list.append("%s = %s" % (c_name, m_col_val_float))
                else:
                    sql_where_list.append("%s = '%s'" % (c_name, self.db.escape(m_col_val)))

        if other_clauses is None:
            pass
        elif type(other_clauses) == type(""):
            sql_where_list = sql_where_list + [other_clauses]
        elif type(other_clauses) == type([]):
            sql_where_list = sql_where_list + other_clauses
        else:
            raise eInvalidData, "unknown type of extra where clause: %s" % repr(other_clauses)
                    
        return sql_where_list

    def __fetchRows(self,col_match_spec,cursor = None, where = None, order_by = None, limit_to = None,
                    skip_to = None, join = None):
        if cursor is None:
            cursor = self.db.defaultCursor()

        # build column list
        sql_columns = []
        for name,t,options in self.__column_list:
            sql_columns.append(name)

        # build join information

        joined_cols = []
        joined_cols_hash = {}
        join_clauses = []
        if not join is None:
            for a_table,retrieve_foreign_cols in join:
                try:
                    my_col,foreign_col = self.__relations_by_table[a_table]
                    for a_col in retrieve_foreign_cols:
                        full_col_name = "%s.%s" % (my_col,a_col)
                        joined_cols_hash[full_col_name] = 1
                        joined_cols.append(full_col_name)
                        sql_columns.append( full_col_name )

                    join_clauses.append(" left join %s as %s on %s=%s " % (a_table,my_col,my_col,foreign_col))
                        
                except KeyError:
                    eInvalidJoinSpec, "can't find table %s in defined relations for %s" % (a_table,self.__table_name)
                    
        # start buildling SQL
        sql = "select %s from %s" % (string.join(sql_columns,","),
                                     self.__table_name)

        # add join clause
        if join_clauses:
            sql = sql + string.join(join_clauses," ")
        
        # add where clause elements
        sql_where_list = self.__buildWhereClause (col_match_spec,where)
        if sql_where_list:
            sql = sql + " where %s" % (string.join(sql_where_list," and "))

        # add order by clause
        if order_by:
            sql = sql + " order by %s " % string.join(order_by,",")

        # add limit
        if not limit_to is None:
            if not skip_to is None:
#                log("limit,skip = %s,%s" % (limit_to,skip_to))
                if self.db.db.__module__ == "sqlite.main":
                    sql = sql + " limit %s offset %s " % (limit_to,skip_to)
                else:
                    sql = sql + " limit %s, %s" % (skip_to,limit_to)
            else:
                sql = sql + " limit %s" % limit_to
        else:
            if not skip_to is None:
                raise eInvalidData, "can't specify skip_to without limit_to in MySQL"

        dlog(DEV_SELECT,sql)
        cursor.execute(sql)

        # create defaultRowListClass instance...
        return_rows = self.__defaultRowListClass()
            
        # should do fetchmany!
        all_rows = cursor.fetchall()
        for a_row in all_rows:
            data_dict = {}

            col_num = 0
            
            #            for a_col in cursor.description:
            #                (name,type_code,display_size,internal_size,precision,scale,null_ok) = a_col
            for name in sql_columns:
                if self.__col_def_hash.has_key(name) or joined_cols_hash.has_key(name):
                    # only include declared columns!
                    if self.__col_def_hash.has_key(name):
                        c_name,c_type,c_options = self.__col_def_hash[name]
                        if c_type == kBigString and c_options.get("compress_ok",0) and a_row[col_num]:
                            try:
                                a_col_data = zlib.decompress(a_row[col_num])
                            except zlib.error:
                                a_col_data = a_row[col_num]

                            data_dict[name] = a_col_data
                        elif c_type == kInteger or c_type == kIncInteger:
                            value = a_row[col_num]
                            if not value is None:
                                data_dict[name] = int(value)
                            else:
                                data_dict[name] = None
                        else:
                            data_dict[name] = a_row[col_num]

                    else:
                        data_dict[name] = a_row[col_num]
                        
                    col_num = col_num + 1

	    newrowobj = self.__defaultRowClass(self,data_dict,joined_cols = joined_cols)
	    return_rows.append(newrowobj)
	      

            
        return return_rows

    def __deleteRow(self,a_row,cursor = None):
        """Delete the row identified by a_row.getPKMatchSpec().

        cursor -- cursor to use; defaults to the database's shared cursor
        """
        if cursor is None:
            cursor = self.db.defaultCursor()

        conditions = self.__buildWhereClause (a_row.getPKMatchSpec())
        where_sql = string.join(conditions, " and ")
        sql = "delete from %s where %s" % (self.__table_name, where_sql)

        dlog(DEV_UPDATE,sql)
        cursor.execute(sql)
       

    def __updateRowList(self,a_row_list,cursor = None):
        if cursor is None:
            cursor = self.db.defaultCursor()

        for a_row in a_row_list:
            update_list = a_row.changedList()

            # build the set list!
            sql_set_list = []
            for a_change in update_list:
                col_name,col_val,col_inc_val = a_change
                c_name,c_type,c_options = self.__col_def_hash[col_name]

                if c_type != kIncInteger and col_val is None:
                    sql_set_list.append("%s = NULL" % c_name)
                elif c_type == kIncInteger and col_inc_val is None:
                    sql_set_list.append("%s = 0" % c_name)
                else:
                    if c_type == kInteger:
                        sql_set_list.append("%s = %d" % (c_name, long(col_val)))
                    elif c_type == kIncInteger:
                        sql_set_list.append("%s = %s + %d" % (c_name,c_name,long(col_inc_val)))
                    elif c_type == kBigString and c_options.get("compress_ok",0) and self.db.compression_enabled:
                        compressed_data = zlib.compress(col_val,9)
                        if len(compressed_data) < len(col_val):
                            sql_set_list.append("%s = '%s'" % (c_name, self.db.escape(compressed_data)))
                        else:
                            sql_set_list.append("%s = '%s'" % (c_name, self.db.escape(col_val)))
                    elif c_type == kReal:
                        sql_set_list.append("%s = %s" % (c_name,float(col_val)))

                    else:
                        sql_set_list.append("%s = '%s'" % (c_name, self.db.escape(col_val)))

            # build the where clause!
            match_spec = a_row.getPKMatchSpec()
            sql_where_list = self.__buildWhereClause (match_spec)

            if sql_set_list:
                sql = "update %s set %s where %s" % (self.__table_name,
                                                 string.join(sql_set_list,","),
                                                 string.join(sql_where_list," and "))

                dlog(DEV_UPDATE,sql)
                try:
                    cursor.execute(sql)
                except Exception, reason:
                    if string.find(str(reason), "Duplicate entry") != -1:
                        raise eDuplicateKey, reason
                    raise Exception, reason
                a_row.markClean()

    def __insertRow(self,a_row_obj,cursor = None,replace=0):
        if cursor is None:
            cursor = self.db.defaultCursor()

        sql_col_list = []
        sql_data_list = []
        auto_increment_column_name = None

        for a_col in self.__column_list:
            name,type,options = a_col

            try:
                data = a_row_obj[name]

                sql_col_list.append(name)
                if data is None:
                    sql_data_list.append("NULL")
                else:
                    if type == kInteger or type == kIncInteger:
                        sql_data_list.append("%d" % data)
                    elif type == kBigString and options.get("compress_ok",0) and self.db.compression_enabled:
                        compressed_data = zlib.compress(data,9)
                        if len(compressed_data) < len(data):
                            sql_data_list.append("'%s'" % self.db.escape(compressed_data))
                        else:
                            sql_data_list.append("'%s'" % self.db.escape(data))
                    elif type == kReal:
                        sql_data_list.append("%s" % data)
                    else:
                        sql_data_list.append("'%s'" % self.db.escape(data))

            except KeyError:
                if options.has_key("autoincrement"):
                    if auto_increment_column_name:
                        raise eInternalError, "two autoincrement columns (%s,%s) in table (%s)" % (auto_increment_column_name, name,self.__table_name)
                    else:
                        auto_increment_column_name = name

        if replace:
            sql = "replace into %s (%s) values (%s)" % (self.__table_name,
                                                   string.join(sql_col_list,","),
                                                   string.join(sql_data_list,","))
        else:
            sql = "insert into %s (%s) values (%s)" % (self.__table_name,
                                                   string.join(sql_col_list,","),
                                                   string.join(sql_data_list,","))

        dlog(DEV_UPDATE,sql)
        try:
          cursor.execute(sql)
        except Exception, reason:
          # sys.stderr.write("errror in statement: " + sql + "\n")
          log("error in statement: " + sql + "\n")
          if string.find(str(reason), "Duplicate entry") != -1:
            raise eDuplicateKey, reason
          raise Exception, reason
            
        if auto_increment_column_name:
            if cursor.__module__ == "sqlite.main":
                a_row_obj[auto_increment_column_name] = cursor.lastrowid
            elif cursor.__module__ == "MySQLdb.cursors":
                a_row_obj[auto_increment_column_name] = cursor.insert_id()
            else:
                # fallback to acting like mysql
                a_row_obj[auto_increment_column_name] = cursor.insert_id()

    # ----------------------------------------------------
    #   Helper methods for Rows...
    # ----------------------------------------------------


        
    #####################
    # r_deleteRow(a_row_obj,cursor = None)
    #
    # normally this is called from within the Row "delete()" method
    # but you can call it yourself if you want
    #

    def r_deleteRow(self,a_row_obj, cursor = None):
        """Delete a_row_obj via the table's private delete implementation.

        cursor is forwarded as-is (None means "let the table choose").
        """
        self.__deleteRow(a_row_obj, cursor = cursor)


    #####################
    # r_updateRow(a_row_obj,cursor = None)
    #
    # normally this is called from within the Row "save()" method
    # but you can call it yourself if you want
    #

    def r_updateRow(self,a_row_obj, cursor = None):
        """Write a_row_obj's changed columns with an UPDATE.

        Delegates to the list-based update routine with a one-element list;
        cursor is forwarded as-is.
        """
        single_row_list = [a_row_obj]
        self.__updateRowList(single_row_list, cursor = cursor)

    #####################
    # r_insertRow(a_row_obj,cursor = None,replace=0)
    #
    # normally this is called from within the Row "save()" method
    # but you can call it yourself if you want
    #

    def r_insertRow(self,a_row_obj, cursor = None,replace=0):
        """Insert a_row_obj (or REPLACE it when replace is true).

        Thin public wrapper around the private insert implementation;
        cursor and replace are forwarded unchanged.
        """
        self.__insertRow(a_row_obj, replace = replace, cursor = cursor)


    # ----------------------------------------------------
    #   Public Methods
    # ----------------------------------------------------


        
    #####################
    # deleteRow(col_match_spec, where = None)
    #
    # The col_match_spec parameters must include all primary key columns.
    #
    # Ex:
    #    tbl.deleteRow( ("order_id", 1) )
    #    tbl.deleteRow( [ ("order_id", 1), ("enterTime", now) ] )


    def deleteRow(self,col_match_spec, where=None, cursor=None):
        """Delete the rows selected by col_match_spec (and optional where).

        col_match_spec -- match spec, normalised by _fixColMatchSpec.
        where          -- extra where-clause material, passed to
                          __buildWhereClause alongside the match spec.
        cursor         -- optional DB-API cursor; defaults to
                          self.db.defaultCursor() (same as fetchRows).

        If the resulting where clause is empty, nothing is executed, so an
        empty spec never becomes an unqualified "delete from".
        """
        n_match_spec = self._fixColMatchSpec(col_match_spec)

        # build sql where clause elements
        sql_where_list = self.__buildWhereClause (n_match_spec,where)
        if not sql_where_list:
            return

        # only acquire a cursor once we know a statement will run
        if cursor is None:
            cursor = self.db.defaultCursor()

        sql = "delete from %s where %s" % (self.__table_name, string.join(sql_where_list," and "))

        dlog(DEV_UPDATE,sql)
        cursor.execute(sql)
        
    #####################
    # fetchRow(col_match_spec)
    #
    # The col_match_spec parameters must include all primary key columns.
    #
    # Ex:
    #    a_row = tbl.fetchRow( ("order_id", 1) )
    #    a_row = tbl.fetchRow( [ ("order_id", 1), ("enterTime", now) ] )


    def fetchRow(self, col_match_spec, cursor = None):
        n_match_spec = self._fixColMatchSpec(col_match_spec, should_match_unique_row = 1)

        rows = self.__fetchRows(n_match_spec, cursor = cursor)
        if len(rows) == 0:
            raise eNoMatchingRows, "no row matches %s" % repr(n_match_spec)

        if len(rows) > 1:
            raise eInternalError, "unique where clause shouldn't return > 1 row"

        return rows[0]
            

    #####################
    # fetchRows(col_match_spec)
    #
    # Ex:
    #    a_row_list = tbl.fetchRows( ("order_id", 1) )
    #    a_row_list = tbl.fetchRows( [ ("order_id", 1), ("enterTime", now) ] )


    def fetchRows(self, col_match_spec = None, cursor = None,
                  where = None, order_by = None, limit_to = None,
                  skip_to = None, join = None):
        """Return all rows matching col_match_spec (possibly an empty list).

        Unlike fetchRow(), the match spec is not required to identify a
        unique row.  cursor, where, order_by, limit_to, skip_to and join are
        forwarded unchanged to the internal fetch routine.
        """
        normalized = self._fixColMatchSpec(col_match_spec)
        fetch_options = {
            'cursor': cursor,
            'where': where,
            'order_by': order_by,
            'limit_to': limit_to,
            'skip_to': skip_to,
            'join': join,
        }
        return self.__fetchRows(normalized, **fetch_options)

    def fetchRowCount (self, col_match_spec = None,
                       cursor = None, where = None):
        """Return how many rows match col_match_spec and where.

        col_match_spec and where are interpreted as in fetchRows().
        cursor defaults to self.db.defaultCursor().  Returns 0 if the
        driver yields no result row.
        """
        n_match_spec = self._fixColMatchSpec(col_match_spec)
        sql_where_list = self.__buildWhereClause (n_match_spec,where)
        sql = "select count(*) from %s" % self.__table_name
        if sql_where_list:
            sql = "%s where %s" % (sql,string.join(sql_where_list," and "))
        if cursor is None:
            cursor = self.db.defaultCursor()
        dlog(DEV_SELECT,sql)
        cursor.execute(sql)

        result = cursor.fetchone()
        if result is None:
            # no result row: report zero rather than failing to unpack
            return 0
        count, = result
        return count


    #####################
    # fetchAllRows()
    #
    # Ex:
    #    a_row_list = tbl.fetchAllRows()

    def fetchAllRows(self):
        """Return every row of the table.

        If the internal fetch reports eNoMatchingRows, an empty instance of
        the table's default row-list class is returned instead.
        """
        try:
            all_rows = self.__fetchRows([])
        except eNoMatchingRows:
            all_rows = self.__defaultRowListClass()
        return all_rows

    def newRow(self,replace=0):
        """Create a new, unsaved Row for this table.

        replace -- if true, the row's save() will use "replace into"
                   instead of "insert into".

        Every column whose options carry a non-None 'default' is
        pre-populated with it; kIncInteger columns are left untouched.
        """
        row = self.__defaultRowClass(self,None,create=1,replace=replace)
        for (cname, ctype, opts) in self.__column_list:
            # .get(): a column defined without a 'default' option simply has
            # no default, rather than raising KeyError
            default = opts.get('default')
            # compare the type constant by value; identity ("is") only
            # works when the constant happens to be a shared object
            if default is not None and ctype != kIncInteger:
                row[cname] = default
        return row

class Row:
    __instance_data_locked  = 0
    def subclassinit(self):
        pass
    def __init__(self,_table,data_dict,create=0,joined_cols = None,replace=0):

        self._inside_getattr = 0  # stop recursive __getattr__
        self._table = _table
        self._should_insert = create or replace
        self._should_replace = replace
        self._rowInactive = None
        self._joinedRows = []
        
        self.__pk_match_spec = None
        self.__vcoldata = {}
        self.__inc_coldata = {}

        self.__joined_cols_dict = {}
        for a_col in joined_cols or []:
            self.__joined_cols_dict[a_col] = 1
        
        if create:
            self.__coldata = {}
        else:
            if type(data_dict) != type({}):
                raise eInternalError, "rowdict instantiate with bad data_dict"
            self.__coldata = data_dict
            self.__unpackVColumn()

        self.markClean()

        self.subclassinit()
        self.__instance_data_locked = 1

    def joinRowData(self,another_row):
        self._joinedRows.append(another_row)

    def getPKMatchSpec(self):
        return self.__pk_match_spec

    def markClean(self):
        self.__vcolchanged = 0
        self.__colchanged_dict = {}

        for key in self.__inc_coldata.keys():
            self.__coldata[key] = self.__coldata.get(key, 0) + self.__inc_coldata[key]

        self.__inc_coldata = {}

        if not self._should_insert:
            # rebuild primary column match spec
            new_match_spec = []
            for col_name in self._table.getPrimaryKeyList():
                try:
                    rdata = self[col_name]
                except KeyError:
                    raise eInternalError, "must have primary key data filled in to save %s:Row(col:%s)" % (self._table.getTableName(),col_name)
                    
                new_match_spec.append( (col_name, rdata) )
            self.__pk_match_spec = new_match_spec

    def __unpackVColumn(self):
        if self._table.hasValueColumn():
            pass
        
    def __packVColumn(self):
        if self._table.hasValueColumn():
            pass

    ## ----- utility stuff ----------------------------------

    def __del__(self):
        # check for unsaved changes
        changed_list = self.changedList()
        if len(changed_list):
            info = "unsaved Row for table (%s) lost, call discard() to avoid this error. Lost changes: %s\n" % (self._table.getTableName(), repr(changed_list)[:256])
            if 0:
                raise eUnsavedObjectLost, info
            else:
                sys.stderr.write(info)
                

    def __repr__(self):
        return "Row from (%s): %s" % (self._table.getTableName(),repr(self.__coldata) + repr(self.__vcoldata))

    ## ---- class emulation --------------------------------

    def __getattr__(self,key):
        if self._inside_getattr:
          raise AttributeError, "recursively called __getattr__ (%s,%s)" % (key,self._table.getTableName())
        try:
            self._inside_getattr = 1
            try:
                return self[key]
            except KeyError:
                if self._table.hasColumn(key) or self._table.hasVColumn(key):
                    return None
                else:
                    raise AttributeError, "unknown field '%s' in Row(%s)" % (key,self._table.getTableName())
        finally:
            self._inside_getattr = 0

    def __setattr__(self,key,val):
        if not self.__instance_data_locked:
            self.__dict__[key] = val
        else:
            my_dict = self.__dict__
            if my_dict.has_key(key):
                my_dict[key] = val
            else:
                # try and put it into the rowdata
                try:
                    self[key] = val
                except KeyError, reason:
                    raise AttributeError, reason


    ## ---- dict emulation ---------------------------------
    
    def __getitem__(self,key):
        self.checkRowActive()

        try:
            c_type = self._table.columnType(key)
        except eNoSuchColumn, reason:
            # Ugh, this sucks, we can't determine the type for a joined
            # row, so we just default to kVarString and let the code below
            # determine if this is a joined column or not
            c_type = kVarString

        if c_type == kIncInteger:
            c_data = self.__coldata.get(key, 0) 
            if c_data is None: c_data = 0
            i_data = self.__inc_coldata.get(key, 0)
            if i_data is None: i_data = 0
            return c_data + i_data
        
        try:
            return self.__coldata[key]
        except KeyError:
            try:
                return self.__vcoldata[key]
            except KeyError:
                for a_joined_row in self._joinedRows:
                    try:
                        return a_joined_row[key]
                    except KeyError:
                        pass

                raise KeyError, "unknown column %s in %s" % (key,self)

    def __setitem__(self,key,data):
        self.checkRowActive()
        
        try:
            newdata = self._table.convertDataForColumn(data,key)
        except eNoSuchColumn, reason:
            raise KeyError, reason

        if self._table.hasColumn(key):
            self.__coldata[key] = newdata
            self.__colchanged_dict[key] = 1
        elif self._table.hasVColumn(key):
            self.__vcoldata[key] = newdata
            self.__vcolchanged = 1
        else:
            for a_joined_row in self._joinedRows:
                try:
                    a_joined_row[key] = data
                    return
                except KeyError:
                    pass
            raise KeyError, "unknown column name %s" % key

    def __delitem__(self,key,data):
        self.checkRowActive()
        
        if self.table.hasVColumn(key):
            del self.__vcoldata[key]
        else:
            for a_joined_row in self._joinedRows:
                try:
                    del a_joined_row[key]
                    return
                except KeyError:
                    pass
            raise KeyError, "unknown column name %s" % key


    def copyFrom(self,source):
        for name,t,options in self._table.getColumnList():
            if not options.has_key("autoincrement"):
                self[name] = source[name]


    # make sure that .keys(), and .items() come out in a nice order!

    def keys(self):
        self.checkRowActive()
        
        key_list = []
        for name,t,options in self._table.getColumnList():
            key_list.append(name)
        for name in self.__joined_cols_dict.keys():
            key_list.append(name)

        for a_joined_row in self._joinedRows:
            key_list = key_list + a_joined_row.keys()
            
        return key_list


    def items(self):
        self.checkRowActive()
        
        item_list = []
        for name,t,options in self._table.getColumnList():
            item_list.append( (name,self[name]) )
        for name in self.__joined_cols_dict.keys():
            item_list.append( (name,self[name]) )

        for a_joined_row in self._joinedRows:
            item_list = item_list + a_joined_row.items()


        return item_list

    def values(elf):
        self.checkRowActive()

        value_list = self.__coldata.values() + self.__vcoldata.values()

        for a_joined_row in self._joinedRows:
            value_list = value_list + a_joined_row.values()

        return value_list
        

    def __len__(self):
        self.checkRowActive()
        
        my_len = len(self.__coldata) + len(self.__vcoldata)

        for a_joined_row in self._joinedRows:
            my_len = my_len + len(a_joined_row)

        return my_len

    def has_key(self,key):
        self.checkRowActive()
        
        if self.__coldata.has_key(key) or self.__vcoldata.has_key(key):
            return 1
        else:

            for a_joined_row in self._joinedRows:
                if a_joined_row.has_key(key):
                    return 1
            return 0
        
    def get(self,key,default = None):
        self.checkRowActive()

        
        
        if self.__coldata.has_key(key):
            return self.__coldata[key]
        elif self.__vcoldata.has_key(key):
            return self.__vcoldata[key]
        else:
            for a_joined_row in self._joinedRows:
                try:
                    return a_joined_row.get(key,default)
                except eNoSuchColumn:
                    pass

            if self._table.hasColumn(key):
                return default
            
            raise eNoSuchColumn, "no such column %s" % key

    def inc(self,key,count=1):
        self.checkRowActive()

        if self._table.hasColumn(key):
            try:
                self.__inc_coldata[key] = self.__inc_coldata[key] + count
            except KeyError:
                self.__inc_coldata[key] = count

            self.__colchanged_dict[key] = 1
        else:
            raise AttributeError, "unknown field '%s' in Row(%s)" % (key,self._table.getTableName())
    

    ## ----------------------------------
    ## real interface


    def fillDefaults(self):
        for field_def in self._table.fieldList():
            name,type,size,options = field_def
            if options.has_key("default"):
                self[name] = options["default"]

    ###############
    # changedList()
    #
    # returns a list of tuples for the columns which have changed
    #
    #   changedList() -> [ ('name', 'fred'), ('age', 20) ]

    def changedList(self):
        if self.__vcolchanged:
            self.__packVColumn()

        changed_list = []
        for a_col in self.__colchanged_dict.keys():
            changed_list.append( (a_col,self.get(a_col,None),self.__inc_coldata.get(a_col,None)) )

        return changed_list

    def discard(self):
        self.__coldata = None
        self.__vcoldata = None
        self.__colchanged_dict = {}
        self.__vcolchanged = 0

    def delete(self,cursor = None):
        self.checkRowActive()

        
        fromTable = self._table
        curs = cursor
        fromTable.r_deleteRow(self,cursor=curs)
        self._rowInactive = "deleted"

    def save(self,cursor = None):
        toTable = self._table

        self.checkRowActive()

        if self._should_insert:
            toTable.r_insertRow(self,replace=self._should_replace)
            self._should_insert = 0
            self._should_replace = 0
            self.markClean()  # rebuild the primary key list
        else:
            curs = cursor
            toTable.r_updateRow(self,cursor = curs)

        # the table will mark us clean!
        # self.markClean()

    def checkRowActive(self):
        if self._rowInactive:
            raise eInvalidData, "row is inactive: %s" % self._rowInactive

    def databaseSizeForColumn(self,key):
        return self._table.databaseSizeForData_ColumnName_(self[key],key)


if __name__ == "__main__":
    # This module is a library; when run directly, point at the test script.
    print("run odb_test.py")