#!/usr/bin/env python
#
# odb.py
#
# Object Database Api
#
# Written by David Jeske <jeske@neotonic.com>, 2001/07.
# Inspired by eGroups' sqldb.py originally written by Scott Hassan circa 1998.
#
# Copyright (C) 2001, by David Jeske and Neotonic
#
# Goals:
# - a simple object-like interface to database data
# - database independent (someday)
# - relational-style "rigid schema definition"
# - object style easy-access
#
# Example:
#
# import odb
#
# # define table
# class AgentsTable(odb.Table):
# def _defineRows(self):
# self.d_addColumn("agent_id",kInteger,None,primarykey = 1,autoincrement = 1)
# self.d_addColumn("login",kVarString,200,notnull=1)
# self.d_addColumn("ticket_count",kIncInteger,None)
#
# if __name__ == "__main__":
# # open database
# ndb = MySQLdb.connect(host = 'localhost',
# user='username',
# passwd = 'password',
# db='testdb')
# db = Database(ndb)
# tbl = AgentsTable(db,"agents")
#
# # create row
# agent_row = tbl.newRow()
# agent_row.login = "foo"
# agent_row.save()
#
# # fetch row (must use primary key)
# try:
# get_row = tbl.fetchRow( ('agent_id', agent_row.agent_id) )
# except odb.eNoMatchingRows:
# print "this is bad, we should have found the row"
#
# # fetch rows (can return empty list)
# list_rows = tbl.fetchRows( ('login', "foo") )
#
import string
import sys, zlib
from log import *
import handle_error
# Exceptions raised by this module.
#
# These used to be plain strings ("string exceptions"), which can no longer
# be raised on Python 2.6 and later.  They are real exception classes now;
# both `raise eNoSuchColumn, "msg"` and `except odb.eNoSuchColumn:` keep
# working unchanged.
class odb_Exception(Exception):
    """Common base class of every odb error."""


class eNoSuchColumn(odb_Exception):
    """A column name is not defined on the table."""


class eNonUniqueMatchSpec(odb_Exception):
    """A match spec that must identify a single row does not cover a unique key."""


class eNoMatchingRows(odb_Exception):
    """fetchRow() found no row for the given match spec."""


class eInternalError(odb_Exception):
    """An internal consistency check failed."""


class eInvalidMatchSpec(odb_Exception):
    """A match spec is neither a (name, value) tuple nor a list of them."""


class eInvalidData(odb_Exception):
    """A value or a column option is not acceptable."""


class eUnsavedObjectLost(odb_Exception):
    """A Row with unsaved changes was discarded."""


class eDuplicateKey(odb_Exception):
    """The database reported a duplicate entry on insert or update."""
#####################################
# COLUMN TYPES
#####################################
# Each constant names a column type accepted by Table.d_addColumn().
# The trailing comment says what the "size" argument means for that type
# ("-" means the size is not used when the SQL type is generated).
#
# typename                        size data means:
kInteger = "kInteger" # - (created as SQL "integer")
kFixedString = "kFixedString" # size (created as SQL "char(size)")
kVarString = "kVarString" # maxsize (created as SQL "varchar(maxsize)")
kBigString = "kBigString" # - (created as SQL "text")
kIncInteger = "kIncInteger" # - (SQL "integer"; updates are written as "col = col + delta")
kDateTime = "kDateTime" # - (created as SQL "datetime")
kTimeStamp = "kTimeStamp" # - (created as SQL "timestamp")
kReal = "kReal" # - (created as SQL "real")
# not referenced anywhere in the visible part of this module
DEBUG = 0
##############
# Database
#
# this will ultimately turn into a mostly abstract base class for
# the DB adaptors for different database types....
#
class Database:
    """Owns a DB-API style connection and the Table objects attached to it.

    Tables registered with addTable() are reachable as attributes of the
    Database instance (see __getattr__).  escape(), listTables(),
    listIndices() and listFieldsDict() are placeholders that a
    database-specific subclass has to implement.
    """

    def __init__(self, db, debug=0):
        """Wrap the connection object *db*; *debug* is stored as self.debug."""
        self._tables = {}
        self.db = db
        self._cursor = None
        self.compression_enabled = 0
        self.debug = debug
        # exception class raised by the underlying driver; set by subclasses
        self.SQLError = None
        self.__defaultRowClass = self.defaultRowClass()
        self.__defaultRowListClass = self.defaultRowListClass()

    def defaultCursor(self):
        """Return the shared cursor, creating it on first use."""
        if self._cursor is None:
            self._cursor = self.db.cursor()
        return self._cursor

    def escape(self, str):
        """Escape a value for inclusion in a quoted SQL literal (subclass hook)."""
        raise NotImplementedError("Unimplemented Error")

    def getDefaultRowClass(self):
        return self.__defaultRowClass

    def setDefaultRowClass(self, clss):
        self.__defaultRowClass = clss

    def getDefaultRowListClass(self):
        return self.__defaultRowListClass

    def setDefaultRowListClass(self, clss):
        self.__defaultRowListClass = clss

    def defaultRowClass(self):
        """Class used for rows unless a table or caller overrides it."""
        return Row

    def defaultRowListClass(self):
        """Class used for row lists; the base type is list."""
        return list

    def addTable(self, attrname, tblname, tblclass,
                 rowClass=None, check=0, create=0, rowListClass=None):
        """Instantiate *tblclass* for table *tblname* and register it as *attrname*.

        Returns the new table object; it is afterwards also available as
        an attribute of this Database named *attrname*.
        """
        tbl = tblclass(self, tblname, rowClass=rowClass, check=check,
                       create=create, rowListClass=rowListClass)
        self._tables[attrname] = tbl
        return tbl

    def close(self):
        """Detach all tables and close the underlying connection."""
        for tbl in self._tables.values():
            tbl.db = None
        self._tables = {}
        # the cached cursor belongs to the connection being closed
        self._cursor = None
        if self.db is not None:
            self.db.close()
            self.db = None

    def __getattr__(self, key):
        """Resolve unknown attributes to registered tables.

        Raises AttributeError for names that are not registered tables.
        """
        if key == "_tables":
            # only reached when __init__ never ran; avoids infinite recursion
            raise AttributeError("odb.Database: not initialized properly, self._tables does not exist")
        try:
            return self._tables[key]
        except KeyError:
            raise AttributeError("odb.Database: unknown attribute %s" % (key))

    def beginTransaction(self, cursor=None):
        if cursor is None:
            cursor = self.defaultCursor()
        dlog(DEV_UPDATE, "begin")
        cursor.execute("begin")

    def commitTransaction(self, cursor=None):
        if cursor is None:
            cursor = self.defaultCursor()
        dlog(DEV_UPDATE, "commit")
        cursor.execute("commit")

    def rollbackTransaction(self, cursor=None):
        if cursor is None:
            cursor = self.defaultCursor()
        dlog(DEV_UPDATE, "rollback")
        cursor.execute("rollback")

    ##
    ## schema creation code
    ##
    def createTables(self):
        """Create every registered table that is missing from the database.

        Tables that already exist are only checked; differences are
        reported by Table.checkTable() and nothing is altered.
        """
        tables = self.listTables()
        for tbl in self._tables.values():
            tblname = tbl.getTableName()
            if tblname not in tables:
                print("table %s does not exist" % tblname)
                tbl.createTable()
            else:
                tbl.checkTable()

    def createIndices(self):
        """Create every index declared on the registered tables that is missing."""
        indices = self.listIndices()
        for tbl in self._tables.values():
            for indexName, (columns, unique) in tbl.getIndices().items():
                if indexName in indices:
                    continue
                tbl.createIndex(columns, indexName=indexName, unique=unique)

    def synchronizeSchema(self):
        """Ask every registered table to be altered to match its definition."""
        for tbl in self._tables.values():
            self.alterTableToMatch(tbl)

    def alterTableToMatch(self, tbl):
        """Alter *tbl* in the database to match its definition.

        The default delegates to the table object; subclasses may override.
        """
        tbl.alterTableToMatch()

    def listTables(self, cursor=None):
        """Return the names of the tables in the database (subclass hook)."""
        raise NotImplementedError("Unimplemented Error")

    def listIndices(self, cursor=None):
        """Return the names of the existing indices (subclass hook)."""
        raise NotImplementedError("Unimplemented Error")

    def listFieldsDict(self, table_name, cursor=None):
        """Return a dict keyed by column name for *table_name* (subclass hook)."""
        raise NotImplementedError("Unimplemented Error")

    def listFields(self, table_name, cursor=None):
        """Return the column names of *table_name*."""
        columns = self.listFieldsDict(table_name, cursor=cursor)
        return columns.keys()
##########################################
# Table
#
class Table:
def subclassinit(self):
    """Hook called by __init__ once the columns are locked; does nothing here."""
    return None
def __init__(self,database,table_name,
rowClass = None, check = 0, create = 0, rowListClass = None):
self.db = database
self.__table_name = table_name
if rowClass:
self.__defaultRowClass = rowClass
else:
self.__defaultRowClass = database.getDefaultRowClass()
if rowListClass:
self.__defaultRowListClass = rowListClass
else:
self.__defaultRowListClass = database.getDefaultRowListClass()
# get this stuff ready!
self.__column_list = []
self.__vcolumn_list = []
self.__columns_locked = 0
self.__has_value_column = 0
self.__indices = {}
# this will be used during init...
self.__col_def_hash = None
self.__vcol_def_hash = None
self.__primary_key_list = None
self.__relations_by_table = {}
# ask the subclass to def his rows
self._defineRows()
# get ready to run!
self.__lockColumnsAndInit()
self.subclassinit()
if create:
self.createTable()
if check:
self.checkTable()
def _colTypeToSQLType(self, colname, coltype, options):
    """Return the "name type [modifiers]" fragment used in CREATE TABLE.

    Unknown column types are passed through unchanged as the SQL type.
    """
    sized_types = {kFixedString: "char", kVarString: "varchar"}
    plain_types = {
        kInteger: "integer",
        kBigString: "text",
        kIncInteger: "integer",
        kDateTime: "datetime",
        kTimeStamp: "timestamp",
        kReal: "real",
    }

    sized_base = sized_types.get(coltype)
    if sized_base is None:
        sqltype = plain_types.get(coltype, coltype)
    else:
        sz = options.get('size', None)
        if sz is None:
            sqltype = sized_base
        else:
            sqltype = sized_base + "(%s)" % sz

    parts = ["%s %s" % (colname, sqltype)]
    if options.get('notnull', 0):
        parts.append("NOT NULL")
    if options.get('autoincrement', 0):
        parts.append("AUTO_INCREMENT")
    if options.get('unique', 0):
        parts.append("UNIQUE")
    # primary keys are emitted separately by _createTableSQL()
    default = options.get('default', None)
    if default is not None:
        parts.append("DEFAULT %s" % default)
    return " ".join(parts)
def getTableName(self): return self.__table_name
def setTableName(self, tablename): self.__table_name = tablename
def getIndices(self): return self.__indices
def _createTableSQL(self):
defs = []
for colname, coltype, options in self.__column_list:
defs.append(self._colTypeToSQLType(colname, coltype, options))
defs = string.join(defs, ", ")
primarykeys = self.getPrimaryKeyList()
primarykey_str = ""
if primarykeys:
primarykey_str = ", PRIMARY KEY (" + string.join(primarykeys, ",") + ")"
sql = "create table %s (%s %s)" % (self.__table_name, defs, primarykey_str)
return sql
def createTable(self, cursor=None):
if cursor is None: cursor = self.db.defaultCursor()
sql = self._createTableSQL()
print "CREATING TABLE:", sql
cursor.execute(sql)
def dropTable(self, cursor=None):
if cursor is None: cursor = self.db.defaultCursor()
try:
cursor.execute("drop table %s" % self.__table_name) # clean out the table
except self.SQLError, reason:
pass
def renameTable(self, newTableName, cursor=None):
if cursor is None: cursor = self.db.defaultCursor()
try:
cursor.execute("rename table %s to %s" % (self.__table_name, newTableName))
except sel.SQLError, reason:
pass
self.setTableName(newTableName)
def getTableColumnsFromDB(self):
return self.db.listFieldsDict(self.__table_name)
def checkTable(self, warnflag=1):
invalidDBCols = {}
invalidAppCols = {}
dbcolumns = self.getTableColumnsFromDB()
for coldef in self.__column_list:
colname = coldef[0]
dbcoldef = dbcolumns.get(colname, None)
if dbcoldef is None:
invalidAppCols[colname] = 1
for colname, row in dbcolumns.items():
coldef = self.__col_def_hash.get(colname, None)
if coldef is None:
invalidDBCols[colname] = 1
if warnflag == 1:
if invalidDBCols:
print "----- WARNING ------------------------------------------"
print " There are columns defined in the database schema that do"
print " not match the application's schema."
print " columns:", invalidDBCols.keys()
print "--------------------------------------------------------"
if invalidAppCols:
print "----- WARNING ------------------------------------------"
print " There are new columns defined in the application schema"
print " that do not match the database's schema."
print " columns:", invalidAppCols.keys()
print "--------------------------------------------------------"
return invalidAppCols, invalidDBCols
def alterTableToMatch(self):
    """Alter the database table to match this definition (not implemented).

    Raises NotImplementedError; the previous string exception cannot be
    raised on Python 2.6 and later.
    """
    raise NotImplementedError("Unimplemented Error!")
def addIndex(self, columns, indexName=None, unique=0):
if indexName is None:
indexName = self.getTableName() + "_index_" + string.join(columns, "_")
self.__indices[indexName] = (columns, unique)
def createIndex(self, columns, indexName=None, unique=0, cursor=None):
    """Create an index over *columns* in the database.

    Without *indexName* the name is "<table>_index_<col1>_<col2>...".
    """
    if cursor is None:
        cursor = self.db.defaultCursor()
    table_name = self.getTableName()
    if indexName is None:
        indexName = table_name + "_index_" + "_".join(columns)
    if unique:
        uniquesql = " unique"
    else:
        uniquesql = ""
    sql = "create %s index %s on %s (%s)" % (
        uniquesql, indexName, self.getTableName(), ",".join(columns))
    warn("creating index", sql)
    cursor.execute(sql)
## Column Definition
def getColumnDef(self,column_name):
try:
return self.__col_def_hash[column_name]
except KeyError:
try:
return self.__vcol_def_hash[column_name]
except KeyError:
raise eNoSuchColumn, "no column (%s) on table %s" % (column_name,self.__table_name)
def getColumnList(self):
return self.__column_list + self.__vcolumn_list
def getAppColumnList(self): return self.__column_list
def databaseSizeForData_ColumnName_(self,data,col_name):
    """Estimate how many bytes *data* occupies when stored in *col_name*.

    kBigString columns report the smaller of the raw and zlib-compressed
    length when compression applies.  Other values report len(data) when
    they can be indexed, and 4 otherwise.
    Raises eNoSuchColumn for unknown columns.
    """
    c_name,c_type,c_options = self.getColumnDef(col_name)
    if c_type == kBigString:
        r_size = len(data)
        if c_options.get("compress_ok",0) and self.db.compression_enabled:
            z_size = len(zlib.compress(data,9))
            if z_size < r_size:
                return z_size
        return r_size
    # really simplistic database size computation:
    try:
        data[0]  # probe: only non-empty sequences are measured by length
        return len(data)
    except Exception:
        # was a bare except, which also swallowed KeyboardInterrupt/SystemExit
        return 4
def columnType(self, col_name):
try:
col_def = self.__col_def_hash[col_name]
except KeyError:
try:
col_def = self.__vcol_def_hash[col_name]
except KeyError:
raise eNoSuchColumn, "no column (%s) on table %s" % (col_name,self.__table_name)
c_name,c_type,c_options = col_def
return c_type
def convertDataForColumn(self,data,col_name):
try:
col_def = self.__col_def_hash[col_name]
except KeyError:
try:
col_def = self.__vcol_def_hash[col_name]
except KeyError:
raise eNoSuchColumn, "no column (%s) on table %s" % (col_name,self.__table_name)
c_name,c_type,c_options = col_def
if c_type == kIncInteger:
raise eInvalidData, "invalid operation for column (%s:%s) on table (%s)" % (col_name,c_type,self.__table_name)
if c_type == kInteger:
try:
if data is None: data = 0
else: return long(data)
except (ValueError,TypeError):
raise eInvalidData, "invalid data (%s) for col (%s:%s) on table (%s)" % (repr(data),col_name,c_type,self.__table_name)
elif c_type == kReal:
try:
if data is None: data = 0.0
else: return float(data)
except (ValueError,TypeError):
raise eInvalidData, "invalid data (%s) for col (%s:%s) on table (%s)" % (repr(data), col_name,c_type,self.__table_name)
else:
if type(data) == type(long(0)):
return "%d" % data
else:
return str(data)
def getPrimaryKeyList(self):
return self.__primary_key_list
def hasValueColumn(self):
return self.__has_value_column
def hasColumn(self,name):
return self.__col_def_hash.has_key(name)
def hasVColumn(self,name):
return self.__vcol_def_hash.has_key(name)
def _defineRows(self):
raise "can't instantiate base odb.Table type, make a subclass and override _defineRows()"
def __lockColumnsAndInit(self):
# add a 'odb_value column' before we lockdown the table def
if self.__has_value_column:
self.d_addColumn("odb_value",kBigText,default='')
self.__columns_locked = 1
# walk column list and make lookup hashes, primary_key_list, etc..
primary_key_list = []
col_def_hash = {}
for a_col in self.__column_list:
name,type,options = a_col
col_def_hash[name] = a_col
if options.has_key('primarykey'):
primary_key_list.append(name)
self.__col_def_hash = col_def_hash
self.__primary_key_list = primary_key_list
# setup the value columns!
if (not self.__has_value_column) and (len(self.__vcolumn_list) > 0):
raise "can't define vcolumns on table without ValueColumn, call d_addValueColumn() in your _defineRows()"
vcol_def_hash = {}
for a_col in self.__vcolumn_list:
name,type,size_data,options = a_col
vcol_def_hash[name] = a_col
self.__vcol_def_hash = vcol_def_hash
def __checkColumnLock(self):
if self.__columns_locked:
raise "can't change column definitions outside of subclass' _defineRows() method!"
# table definition methods, these are only available while inside the
# subclass's _defineRows method
#
# Ex:
#
#   import odb
#   class MyTable(odb.Table):
#     def _defineRows(self):
#       self.d_addColumn("id",kInteger,primarykey = 1,autoincrement = 1)
#       self.d_addColumn("name",kVarString,120)
#       self.d_addColumn("type",kInteger,
#                        enum_values = { 0 : "alive", 1 : "dead" })
def d_addColumn(self,col_name,ctype,size=None,primarykey = 0,
notnull = 0,indexed=0,
default=None,unique=0,autoincrement=0,safeupdate=0,
enum_values = None,
no_export = 0,
relations=None,compress_ok=0,int_date=0):
self.__checkColumnLock()
options = {}
options['default'] = default
if primarykey:
options['primarykey'] = primarykey
if unique:
options['unique'] = unique
if indexed:
options['indexed'] = indexed
self.addIndex((col_name,))
if safeupdate:
options['safeupdate'] = safeupdate
if autoincrement:
options['autoincrement'] = autoincrement
if notnull:
options['notnull'] = notnull
if size:
options['size'] = size
if no_export:
options['no_export'] = no_export
if int_date:
if ctype != kInteger:
raise eInvalidData, "can't flag columns int_date unless they are kInteger"
else:
options['int_date'] = int_date
if enum_values:
options['enum_values'] = enum_values
inv_enum_values = {}
for k,v in enum_values.items():
if inv_enum_values.has_key(v):
raise eInvalidData, "enum_values paramater must be a 1 to 1 mapping for Table(%s)" % self.__table_name
else:
inv_enum_values[v] = k
options['inv_enum_values'] = inv_enum_values
if relations:
options['relations'] = relations
for a_relation in relations:
table, foreign_column_name = a_relation
if self.__relations_by_table.has_key(table):
raise eInvalidData, "multiple relations for the same foreign table are not yet supported"
self.__relations_by_table[table] = (col_name,foreign_column_name)
if compress_ok:
if ctype == kBigString:
options['compress_ok'] = 1
else:
raise eInvalidData, "only kBigString fields can be compress_ok=1"
self.__column_list.append( (col_name,ctype,options) )
def d_addValueColumn(self):
self.__checkColumnLock()
self.__has_value_column = 1
def d_addVColumn(self,col_name,type,size=None,default=None):
self.__checkColumnLock()
if (not self.__has_value_column):
raise "can't define VColumns on table without ValueColumn, call d_addValueColumn() first"
options = {}
if default:
options['default'] = default
if size:
options['size'] = size
self.__vcolumn_list.append( (col_name,type,options) )
#####################
# _fixColMatchSpec(col_match_spec,should_match_unique_row = 0)
#
# raise an error if the col_match_spec contains invalid columns, or
# (in the case of should_match_unique_row) if it does not fully specify
# a unique row.
#
# NOTE: we don't currently support where clauses with value column fields!
#
def _fixColMatchSpec(self,col_match_spec, should_match_unique_row = 0):
if type(col_match_spec) == type([]):
if type(col_match_spec[0]) != type((0,)):
raise eInvalidMatchSpec, "invalid types in match spec, use [(,)..] or (,)"
elif type(col_match_spec) == type((0,)):
col_match_spec = [ col_match_spec ]
elif type(col_match_spec) == type(None):
if should_match_unique_row:
raise eNonUniqueMatchSpec, "can't use a non-unique match spec (%s) here" % col_match_spec
else:
return None
else:
raise eInvalidMatchSpec, "invalid types in match spec, use [(,)..] or (,)"
if should_match_unique_row:
unique_column_lists = []
# first the primary key list
my_primary_key_list = []
for a_key in self.__primary_key_list:
my_primary_key_list.append(a_key)
# then other unique keys
for a_col in self.__column_list:
col_name,a_type,options = a_col
if options.has_key('unique'):
unique_column_lists.append( (col_name, [col_name]) )
unique_column_lists.append( ('primary_key', my_primary_key_list) )
new_col_match_spec = []
for a_col in col_match_spec:
name,val = a_col
# newname = string.lower(name)
# what is this doing?? - jeske
newname = name
if not self.__col_def_hash.has_key(newname):
raise eNoSuchColumn, "no such column in match spec: '%s'" % newname
new_col_match_spec.append( (newname,val) )
if should_match_unique_row:
for name,a_list in unique_column_lists:
try:
a_list.remove(newname)
except ValueError:
# it's okay if they specify too many columns!
pass
if should_match_unique_row:
for name,a_list in unique_column_lists:
if len(a_list) == 0:
# we matched at least one unique colum spec!
# log("using unique column (%s) for query %s" % (name,col_match_spec))
return new_col_match_spec
raise eNonUniqueMatchSpec, "can't use a non-unique match spec (%s) here" % col_match_spec
return new_col_match_spec
def __buildWhereClause (self, col_match_spec,other_clauses = None):
sql_where_list = []
if not col_match_spec is None:
for m_col in col_match_spec:
m_col_name,m_col_val = m_col
c_name,c_type,c_options = self.__col_def_hash[m_col_name]
if c_type in (kIncInteger, kInteger):
try:
m_col_val_long = long(m_col_val)
except ValueError:
raise ValueError, "invalid literal for long(%s) in table %s" % (repr(m_col_val),self.__table_name)
sql_where_list.append("%s = %d" % (c_name, m_col_val_long))
elif c_type == kReal:
try:
m_col_val_float = float(m_col_val)
except ValueError:
raise ValueError, "invalid literal for float(%s) is table %s" % (repr(m_col_val), self.__table_name)
sql_where_list.append("%s = %s" % (c_name, m_col_val_float))
else:
sql_where_list.append("%s = '%s'" % (c_name, self.db.escape(m_col_val)))
if other_clauses is None:
pass
elif type(other_clauses) == type(""):
sql_where_list = sql_where_list + [other_clauses]
elif type(other_clauses) == type([]):
sql_where_list = sql_where_list + other_clauses
else:
raise eInvalidData, "unknown type of extra where clause: %s" % repr(other_clauses)
return sql_where_list
def __fetchRows(self,col_match_spec,cursor = None, where = None, order_by = None, limit_to = None,
skip_to = None, join = None):
if cursor is None:
cursor = self.db.defaultCursor()
# build column list
sql_columns = []
for name,t,options in self.__column_list:
sql_columns.append(name)
# build join information
joined_cols = []
joined_cols_hash = {}
join_clauses = []
if not join is None:
for a_table,retrieve_foreign_cols in join:
try:
my_col,foreign_col = self.__relations_by_table[a_table]
for a_col in retrieve_foreign_cols:
full_col_name = "%s.%s" % (my_col,a_col)
joined_cols_hash[full_col_name] = 1
joined_cols.append(full_col_name)
sql_columns.append( full_col_name )
join_clauses.append(" left join %s as %s on %s=%s " % (a_table,my_col,my_col,foreign_col))
except KeyError:
eInvalidJoinSpec, "can't find table %s in defined relations for %s" % (a_table,self.__table_name)
# start buildling SQL
sql = "select %s from %s" % (string.join(sql_columns,","),
self.__table_name)
# add join clause
if join_clauses:
sql = sql + string.join(join_clauses," ")
# add where clause elements
sql_where_list = self.__buildWhereClause (col_match_spec,where)
if sql_where_list:
sql = sql + " where %s" % (string.join(sql_where_list," and "))
# add order by clause
if order_by:
sql = sql + " order by %s " % string.join(order_by,",")
# add limit
if not limit_to is None:
if not skip_to is None:
# log("limit,skip = %s,%s" % (limit_to,skip_to))
if self.db.db.__module__ == "sqlite.main":
sql = sql + " limit %s offset %s " % (limit_to,skip_to)
else:
sql = sql + " limit %s, %s" % (skip_to,limit_to)
else:
sql = sql + " limit %s" % limit_to
else:
if not skip_to is None:
raise eInvalidData, "can't specify skip_to without limit_to in MySQL"
dlog(DEV_SELECT,sql)
cursor.execute(sql)
# create defaultRowListClass instance...
return_rows = self.__defaultRowListClass()
# should do fetchmany!
all_rows = cursor.fetchall()
for a_row in all_rows:
data_dict = {}
col_num = 0
# for a_col in cursor.description:
# (name,type_code,display_size,internal_size,precision,scale,null_ok) = a_col
for name in sql_columns:
if self.__col_def_hash.has_key(name) or joined_cols_hash.has_key(name):
# only include declared columns!
if self.__col_def_hash.has_key(name):
c_name,c_type,c_options = self.__col_def_hash[name]
if c_type == kBigString and c_options.get("compress_ok",0) and a_row[col_num]:
try:
a_col_data = zlib.decompress(a_row[col_num])
except zlib.error:
a_col_data = a_row[col_num]
data_dict[name] = a_col_data
elif c_type == kInteger or c_type == kIncInteger:
value = a_row[col_num]
if not value is None:
data_dict[name] = int(value)
else:
data_dict[name] = None
else:
data_dict[name] = a_row[col_num]
else:
data_dict[name] = a_row[col_num]
col_num = col_num + 1
newrowobj = self.__defaultRowClass(self,data_dict,joined_cols = joined_cols)
return_rows.append(newrowobj)
return return_rows
def __deleteRow(self,a_row,cursor = None):
    """Delete the database row identified by *a_row*'s primary key match spec."""
    if cursor is None:
        cursor = self.db.defaultCursor()
    # the where clause is built from the row's primary key values
    conditions = self.__buildWhereClause (a_row.getPKMatchSpec())
    sql = "delete from %s where %s" % (self.__table_name, " and ".join(conditions))
    dlog(DEV_UPDATE,sql)
    cursor.execute(sql)
def __updateRowList(self,a_row_list,cursor = None):
if cursor is None:
cursor = self.db.defaultCursor()
for a_row in a_row_list:
update_list = a_row.changedList()
# build the set list!
sql_set_list = []
for a_change in update_list:
col_name,col_val,col_inc_val = a_change
c_name,c_type,c_options = self.__col_def_hash[col_name]
if c_type != kIncInteger and col_val is None:
sql_set_list.append("%s = NULL" % c_name)
elif c_type == kIncInteger and col_inc_val is None:
sql_set_list.append("%s = 0" % c_name)
else:
if c_type == kInteger:
sql_set_list.append("%s = %d" % (c_name, long(col_val)))
elif c_type == kIncInteger:
sql_set_list.append("%s = %s + %d" % (c_name,c_name,long(col_inc_val)))
elif c_type == kBigString and c_options.get("compress_ok",0) and self.db.compression_enabled:
compressed_data = zlib.compress(col_val,9)
if len(compressed_data) < len(col_val):
sql_set_list.append("%s = '%s'" % (c_name, self.db.escape(compressed_data)))
else:
sql_set_list.append("%s = '%s'" % (c_name, self.db.escape(col_val)))
elif c_type == kReal:
sql_set_list.append("%s = %s" % (c_name,float(col_val)))
else:
sql_set_list.append("%s = '%s'" % (c_name, self.db.escape(col_val)))
# build the where clause!
match_spec = a_row.getPKMatchSpec()
sql_where_list = self.__buildWhereClause (match_spec)
if sql_set_list:
sql = "update %s set %s where %s" % (self.__table_name,
string.join(sql_set_list,","),
string.join(sql_where_list," and "))
dlog(DEV_UPDATE,sql)
try:
cursor.execute(sql)
except Exception, reason:
if string.find(str(reason), "Duplicate entry") != -1:
raise eDuplicateKey, reason
raise Exception, reason
a_row.markClean()
def __insertRow(self,a_row_obj,cursor = None,replace=0):
if cursor is None:
cursor = self.db.defaultCursor()
sql_col_list = []
sql_data_list = []
auto_increment_column_name = None
for a_col in self.__column_list:
name,type,options = a_col
try:
data = a_row_obj[name]
sql_col_list.append(name)
if data is None:
sql_data_list.append("NULL")
else:
if type == kInteger or type == kIncInteger:
sql_data_list.append("%d" % data)
elif type == kBigString and options.get("compress_ok",0) and self.db.compression_enabled:
compressed_data = zlib.compress(data,9)
if len(compressed_data) < len(data):
sql_data_list.append("'%s'" % self.db.escape(compressed_data))
else:
sql_data_list.append("'%s'" % self.db.escape(data))
elif type == kReal:
sql_data_list.append("%s" % data)
else:
sql_data_list.append("'%s'" % self.db.escape(data))
except KeyError:
if options.has_key("autoincrement"):
if auto_increment_column_name:
raise eInternalError, "two autoincrement columns (%s,%s) in table (%s)" % (auto_increment_column_name, name,self.__table_name)
else:
auto_increment_column_name = name
if replace:
sql = "replace into %s (%s) values (%s)" % (self.__table_name,
string.join(sql_col_list,","),
string.join(sql_data_list,","))
else:
sql = "insert into %s (%s) values (%s)" % (self.__table_name,
string.join(sql_col_list,","),
string.join(sql_data_list,","))
dlog(DEV_UPDATE,sql)
try:
cursor.execute(sql)
except Exception, reason:
# sys.stderr.write("errror in statement: " + sql + "\n")
log("error in statement: " + sql + "\n")
if string.find(str(reason), "Duplicate entry") != -1:
raise eDuplicateKey, reason
raise Exception, reason
if auto_increment_column_name:
if cursor.__module__ == "sqlite.main":
a_row_obj[auto_increment_column_name] = cursor.lastrowid
elif cursor.__module__ == "MySQLdb.cursors":
a_row_obj[auto_increment_column_name] = cursor.insert_id()
else:
# fallback to acting like mysql
a_row_obj[auto_increment_column_name] = cursor.insert_id()
# ----------------------------------------------------
# Helper methods for Rows...
# ----------------------------------------------------
#####################
# r_deleteRow(a_row_obj,cursor = None)
#
# normally this is called from within the Row "delete()" method
# but you can call it yourself if you want
#
def r_deleteRow(self,a_row_obj, cursor = None):
curs = cursor
self.__deleteRow(a_row_obj, cursor = curs)
#####################
# r_updateRow(a_row_obj,cursor = None)
#
# normally this is called from within the Row "save()" method
# but you can call it yourself if you want
#
def r_updateRow(self,a_row_obj, cursor = None):
curs = cursor
self.__updateRowList([a_row_obj], cursor = curs)
#####################
# r_insertRow(a_row_obj,cursor = None,replace = 0)
#
# normally this is called from within the Row "save()" method
# but you can call it yourself if you want
#
def r_insertRow(self,a_row_obj, cursor = None,replace=0):
curs = cursor
self.__insertRow(a_row_obj, cursor = curs,replace=replace)
# ----------------------------------------------------
# Public Methods
# ----------------------------------------------------
#####################
# deleteRow(col_match_spec, where = None)
#
# Deletes every row matching col_match_spec and the extra where clauses.
# Nothing is deleted when no match criteria are given at all.
#
# Ex:
# tbl.deleteRow( ("order_id", 1) )
# tbl.deleteRow( [ ("order_id", 1), ("enterTime", now) ] )
def deleteRow(self,col_match_spec, where=None):
n_match_spec = self._fixColMatchSpec(col_match_spec)
cursor = self.db.defaultCursor()
# build sql where clause elements
sql_where_list = self.__buildWhereClause (n_match_spec,where)
if not sql_where_list:
return
sql = "delete from %s where %s" % (self.__table_name, string.join(sql_where_list," and "))
dlog(DEV_UPDATE,sql)
cursor.execute(sql)
#####################
# fetchRow(col_match_spec, cursor = None)
#
# The col_match_spec parameters must include all primary key columns
# (or a column declared unique).
#
# Ex:
# a_row = tbl.fetchRow( ("order_id", 1) )
# a_row = tbl.fetchRow( [ ("order_id", 1), ("enterTime", now) ] )
def fetchRow(self, col_match_spec, cursor = None):
n_match_spec = self._fixColMatchSpec(col_match_spec, should_match_unique_row = 1)
rows = self.__fetchRows(n_match_spec, cursor = cursor)
if len(rows) == 0:
raise eNoMatchingRows, "no row matches %s" % repr(n_match_spec)
if len(rows) > 1:
raise eInternalError, "unique where clause shouldn't return > 1 row"
return rows[0]
#####################
# fetchRows(col_match_spec)
#
# Ex:
# a_row_list = tbl.fetchRows( ("order_id", 1) )
# a_row_list = tbl.fetchRows( [ ("order_id", 1), ("enterTime", now) ] )
def fetchRows(self, col_match_spec = None, cursor = None,
where = None, order_by = None, limit_to = None,
skip_to = None, join = None):
n_match_spec = self._fixColMatchSpec(col_match_spec)
return self.__fetchRows(n_match_spec,
cursor = cursor,
where = where,
order_by = order_by,
limit_to = limit_to,
skip_to = skip_to,
join = join)
def fetchRowCount (self, col_match_spec = None,
                   cursor = None, where = None):
    """Return the number of rows matching the spec and extra where clauses.

    Returns 0 when the cursor yields no result row.
    """
    n_match_spec = self._fixColMatchSpec(col_match_spec)
    conditions = self.__buildWhereClause (n_match_spec,where)

    sql = "select count(*) from %s" % self.__table_name
    if conditions:
        sql = "%s where %s" % (sql," and ".join(conditions))

    if cursor is None:
        cursor = self.db.defaultCursor()
    dlog(DEV_SELECT,sql)
    cursor.execute(sql)

    try:
        (count,) = cursor.fetchone()
    except TypeError:
        # fetchone() returned None: no result row
        return 0
    return count
#####################
# fetchAllRows()
#
# Returns every row of the table (possibly an empty list).
#
# Ex:
# a_row_list = tbl.fetchAllRows()
def fetchAllRows(self):
try:
return self.__fetchRows([])
except eNoMatchingRows:
# else return empty list...
return self.__defaultRowListClass()
def newRow(self,replace=0):
    """Return a new, unsaved row pre-filled with the column defaults.

    kIncInteger columns are skipped because they cannot be assigned
    directly.  The column type is now compared by value (it was compared
    with "is not"), and a column without a 'default' option no longer
    raises KeyError.
    """
    row = self.__defaultRowClass(self,None,create=1,replace=replace)
    for (cname, ctype, opts) in self.__column_list:
        default = opts.get('default')
        if default is not None and ctype != kIncInteger:
            row[cname] = default
    return row
class Row:
__instance_data_locked = 0
def subclassinit(self):
    """Hook called at the end of __init__; does nothing in the base class."""
    return None
def __init__(self,_table,data_dict,create=0,joined_cols = None,replace=0):
self._inside_getattr = 0 # stop recursive __getattr__
self._table = _table
self._should_insert = create or replace
self._should_replace = replace
self._rowInactive = None
self._joinedRows = []
self.__pk_match_spec = None
self.__vcoldata = {}
self.__inc_coldata = {}
self.__joined_cols_dict = {}
for a_col in joined_cols or []:
self.__joined_cols_dict[a_col] = 1
if create:
self.__coldata = {}
else:
if type(data_dict) != type({}):
raise eInternalError, "rowdict instantiate with bad data_dict"
self.__coldata = data_dict
self.__unpackVColumn()
self.markClean()
self.subclassinit()
self.__instance_data_locked = 1
def joinRowData(self,another_row):
self._joinedRows.append(another_row)
def getPKMatchSpec(self):
return self.__pk_match_spec
def markClean(self):
self.__vcolchanged = 0
self.__colchanged_dict = {}
for key in self.__inc_coldata.keys():
self.__coldata[key] = self.__coldata.get(key, 0) + self.__inc_coldata[key]
self.__inc_coldata = {}
if not self._should_insert:
# rebuild primary column match spec
new_match_spec = []
for col_name in self._table.getPrimaryKeyList():
try:
rdata = self[col_name]
except KeyError:
raise eInternalError, "must have primary key data filled in to save %s:Row(col:%s)" % (self._table.getTableName(),col_name)
new_match_spec.append( (col_name, rdata) )
self.__pk_match_spec = new_match_spec
def __unpackVColumn(self):
if self._table.hasValueColumn():
pass
def __packVColumn(self):
if self._table.hasValueColumn():
pass
## ----- utility stuff ----------------------------------
def __del__(self):
    """Warn on stderr when a row with unsaved changes is garbage collected."""
    # check for unsaved changes
    pending = self.changedList()
    if not len(pending):
        return
    table_name = self._table.getTableName()
    # only the first 256 characters of the change list are reported
    info = "unsaved Row for table (%s) lost, call discard() to avoid this error. Lost changes: %s\n" % (table_name, repr(pending)[:256])
    sys.stderr.write(info)
def __repr__(self):
return "Row from (%s): %s" % (self._table.getTableName(),repr(self.__coldata) + repr(self.__vcoldata))
## ---- class emulation --------------------------------
def __getattr__(self,key):
if self._inside_getattr:
raise AttributeError, "recursively called __getattr__ (%s,%s)" % (key,self._table.getTableName())
try:
self._inside_getattr = 1
try:
return self[key]
except KeyError:
if self._table.hasColumn(key) or self._table.hasVColumn(key):
return None
else:
raise AttributeError, "unknown field '%s' in Row(%s)" % (key,self._table.getTableName())
finally:
self._inside_getattr = 0
def __setattr__(self,key,val):
if not self.__instance_data_locked:
self.__dict__[key] = val
else:
my_dict = self.__dict__
if my_dict.has_key(key):
my_dict[key] = val
else:
# try and put it into the rowdata
try:
self[key] = val
except KeyError, reason:
raise AttributeError, reason
## ---- dict emulation ---------------------------------
def __getitem__(self,key):
self.checkRowActive()
try:
c_type = self._table.columnType(key)
except eNoSuchColumn, reason:
# Ugh, this sucks, we can't determine the type for a joined
# row, so we just default to kVarString and let the code below
# determine if this is a joined column or not
c_type = kVarString
if c_type == kIncInteger:
c_data = self.__coldata.get(key, 0)
if c_data is None: c_data = 0
i_data = self.__inc_coldata.get(key, 0)
if i_data is None: i_data = 0
return c_data + i_data
try:
return self.__coldata[key]
except KeyError:
try:
return self.__vcoldata[key]
except KeyError:
for a_joined_row in self._joinedRows:
try:
return a_joined_row[key]
except KeyError:
pass
raise KeyError, "unknown column %s in %s" % (key,self)
def __setitem__(self,key,data):
self.checkRowActive()
try:
newdata = self._table.convertDataForColumn(data,key)
except eNoSuchColumn, reason:
raise KeyError, reason
if self._table.hasColumn(key):
self.__coldata[key] = newdata
self.__colchanged_dict[key] = 1
elif self._table.hasVColumn(key):
self.__vcoldata[key] = newdata
self.__vcolchanged = 1
else:
for a_joined_row in self._joinedRows:
try:
a_joined_row[key] = data
return
except KeyError:
pass
raise KeyError, "unknown column name %s" % key
def __delitem__(self,key,data):
self.checkRowActive()
if self.table.hasVColumn(key):
del self.__vcoldata[key]
else:
for a_joined_row in self._joinedRows:
try:
del a_joined_row[key]
return
except KeyError:
pass
raise KeyError, "unknown column name %s" % key
def copyFrom(self,source):
for name,t,options in self._table.getColumnList():
if not options.has_key("autoincrement"):
self[name] = source[name]
# make sure that .keys(), and .items() come out in a nice order!
def keys(self):
self.checkRowActive()
key_list = []
for name,t,options in self._table.getColumnList():
key_list.append(name)
for name in self.__joined_cols_dict.keys():
key_list.append(name)
for a_joined_row in self._joinedRows:
key_list = key_list + a_joined_row.keys()
return key_list
def items(self):
self.checkRowActive()
item_list = []
for name,t,options in self._table.getColumnList():
item_list.append( (name,self[name]) )
for name in self.__joined_cols_dict.keys():
item_list.append( (name,self[name]) )
for a_joined_row in self._joinedRows:
item_list = item_list + a_joined_row.items()
return item_list
def values(elf):
self.checkRowActive()
value_list = self.__coldata.values() + self.__vcoldata.values()
for a_joined_row in self._joinedRows:
value_list = value_list + a_joined_row.values()
return value_list
def __len__(self):
self.checkRowActive()
my_len = len(self.__coldata) + len(self.__vcoldata)
for a_joined_row in self._joinedRows:
my_len = my_len + len(a_joined_row)
return my_len
def has_key(self,key):
self.checkRowActive()
if self.__coldata.has_key(key) or self.__vcoldata.has_key(key):
return 1
else:
for a_joined_row in self._joinedRows:
if a_joined_row.has_key(key):
return 1
return 0
def get(self,key,default = None):
self.checkRowActive()
if self.__coldata.has_key(key):
return self.__coldata[key]
elif self.__vcoldata.has_key(key):
return self.__vcoldata[key]
else:
for a_joined_row in self._joinedRows:
try:
return a_joined_row.get(key,default)
except eNoSuchColumn:
pass
if self._table.hasColumn(key):
return default
raise eNoSuchColumn, "no such column %s" % key
def inc(self,key,count=1):
self.checkRowActive()
if self._table.hasColumn(key):
try:
self.__inc_coldata[key] = self.__inc_coldata[key] + count
except KeyError:
self.__inc_coldata[key] = count
self.__colchanged_dict[key] = 1
else:
raise AttributeError, "unknown field '%s' in Row(%s)" % (key,self._table.getTableName())
## ----------------------------------
## real interface
def fillDefaults(self):
for field_def in self._table.fieldList():
name,type,size,options = field_def
if options.has_key("default"):
self[name] = options["default"]
###############
# changedList()
#
# returns a list of tuples for the columns which have changed
#
# changedList() -> [ ('name', 'fred'), ('age', 20) ]
def changedList(self):
if self.__vcolchanged:
self.__packVColumn()
changed_list = []
for a_col in self.__colchanged_dict.keys():
changed_list.append( (a_col,self.get(a_col,None),self.__inc_coldata.get(a_col,None)) )
return changed_list
def discard(self):
self.__coldata = None
self.__vcoldata = None
self.__colchanged_dict = {}
self.__vcolchanged = 0
def delete(self,cursor = None):
self.checkRowActive()
fromTable = self._table
curs = cursor
fromTable.r_deleteRow(self,cursor=curs)
self._rowInactive = "deleted"
def save(self,cursor = None):
toTable = self._table
self.checkRowActive()
if self._should_insert:
toTable.r_insertRow(self,replace=self._should_replace)
self._should_insert = 0
self._should_replace = 0
self.markClean() # rebuild the primary key list
else:
curs = cursor
toTable.r_updateRow(self,cursor = curs)
# the table will mark us clean!
# self.markClean()
def checkRowActive(self):
if self._rowInactive:
raise eInvalidData, "row is inactive: %s" % self._rowInactive
def databaseSizeForColumn(self,key):
return self._table.databaseSizeForData_ColumnName_(self[key],key)
# This module is a library; when run as a script it only points the
# user at the separate test driver.
if __name__ == "__main__":
    sys.stdout.write("run odb_test.py" + "\n")