Example #1
0
// completeTransaction ends the connection's transaction by passing
// completeType (commit or rollback) to SQLEndTran. On success it clears the
// connection's transaction flag and switches auto-commit back on.
//
// NOTE(review): when SQLEndTran fails, this returns before the flag is cleared
// and before auto-commit is restored, so the connection still reports an
// active transaction afterwards. Confirm whether callers recover from that.
func (tx *transaction) completeTransaction(completeType odbc.SQLSMALLINT) error {
	conn := tx.conn

	// Finish the transaction first; on failure, connection state is left as-is.
	if rc := odbc.SQLEndTran(odbc.SQL_HANDLE_DBC, conn.handle, completeType); isError(rc) {
		return errorConnection(conn.handle)
	}

	// No transaction is open any more: record that and restore auto-commit.
	conn.isTransactionActive = false
	if rc := odbc.SQLSetConnectAttr(conn.handle, odbc.SQL_ATTR_AUTOCOMMIT, odbc.SQLPOINTER(odbc.SQL_AUTOCOMMIT_ON), 0, nil); isError(rc) {
		return errorConnection(conn.handle)
	}
	return nil
}
Example #2
0
// Begin switches the connection to manual-commit mode and returns a
// transaction bound to it. A connection supports at most one transaction at a
// time; calling Begin while one is active returns an error.
func (c *connection) Begin() (driver.Tx, error) {
	// Refuse to nest: only one transaction may be open per connection.
	if c.isTransactionActive {
		return nil, fmt.Errorf("Transaction already active for connection")
	}

	// Turn auto-commit off; the flag is only set once the driver accepted it.
	if rc := odbc.SQLSetConnectAttr(c.handle, odbc.SQL_ATTR_AUTOCOMMIT, odbc.SQLPOINTER(odbc.SQL_AUTOCOMMIT_OFF), 0, nil); isError(rc) {
		return nil, errorConnection(c.handle)
	}

	c.isTransactionActive = true
	return &transaction{conn: c}, nil
}
Example #3
0
// Prepare returns a prepared statement, bound to this connection.
//
// Query options embedded in query are parsed by parseQueryOptions and then
// stripped by removeOptions before the statement is created. The statement is
// registered in c.statements and given a finalizer that calls its Close method.
//
// Once the ODBC statement handle has been allocated, every error path frees it
// again so a failed Prepare does not leak the handle.
func (c *connection) Prepare(query string) (driver.Stmt, error) {
	// Parse query options before allocating anything, so a parse error has no
	// ODBC resource to clean up.
	queryOptions, err := parseQueryOptions(query)
	if err != nil {
		return nil, err
	}

	// Allocate the statement handle
	var stmtHandle odbc.SQLHandle
	ret := odbc.SQLAllocHandle(odbc.SQL_HANDLE_STMT, c.handle, &stmtHandle)
	if isError(ret) {
		return nil, errorConnection(c.handle)
	}

	// fail frees the statement handle (best effort) and returns stmtErr.
	// Callers build stmtErr in the argument list, so it is evaluated while the
	// handle is still valid (errorStatement receives the handle).
	fail := func(stmtErr error) (driver.Stmt, error) {
		odbc.SQLFreeHandle(odbc.SQL_HANDLE_STMT, stmtHandle)
		return nil, stmtErr
	}

	// Set the query timeout
	ret = odbc.SQLSetStmtAttr(stmtHandle, odbc.SQL_ATTR_QUERY_TIMEOUT, odbc.SQLPOINTER(queryTimeout.Seconds()), odbc.SQL_IS_INTEGER)
	if isError(ret) {
		return fail(errorStatement(stmtHandle, query))
	}

	// Get the application parameter descriptor. SQLGetStmtAttr is called on
	// the statement handle, so its diagnostics are read from that handle too.
	var stmtDescHandle odbc.SQLHandle
	ret = odbc.SQLGetStmtAttr(stmtHandle, odbc.SQL_ATTR_APP_PARAM_DESC, uintptr(unsafe.Pointer(&stmtDescHandle)), 0, nil)
	if isError(ret) {
		return fail(errorStatement(stmtHandle, query))
	}

	// Remove query options from SQL query
	query = removeOptions(query)

	// Create new statement
	stmt := &statement{handle: stmtHandle, stmtDescHandle: stmtDescHandle, sqlStmt: query, conn: c, queryOptions: queryOptions}

	// Add to map of statements owned by the connection
	c.statements[stmt] = true

	// Close the statement if it is garbage collected without being closed.
	runtime.SetFinalizer(stmt, (*statement).Close)

	return stmt, nil
}
Example #4
0
// Close invalidates and potentially stops any current prepared statements and
// transactions, marking this connection as no longer in use.
//
// Teardown is best-effort, and every step runs even if an earlier one fails:
//   - owned statements are closed;
//   - an active transaction is rolled back and auto-commit is re-enabled;
//   - the connection is disconnected and its handle freed.
//
// The first error encountered is returned, since later failures are often a
// consequence of it. Close on an already-closed connection, or on one with a
// zero handle, is a no-op that returns nil.
func (c *connection) Close() error {
	// Nothing to release without a live handle or after a previous Close.
	if c.handle == 0 || c.isClosed {
		return nil
	}

	// firstErr keeps the earliest failure. noteErr builds the error right after
	// the failing call, while the connection handle is still valid.
	var firstErr error
	noteErr := func() {
		if firstErr == nil {
			firstErr = errorConnection(c.handle)
		}
	}

	// Close all of the statements owned by the connection. This is
	// best-effort: errors from individual statements are not reported.
	for key := range c.statements {
		// Skip the statement if it is already nil
		if isNil(key) {
			continue
		}
		key.Close()
	}
	c.statements = nil

	// Roll back a transaction still in progress and turn auto-commit back on.
	if c.isTransactionActive {
		if ret := odbc.SQLEndTran(odbc.SQL_HANDLE_DBC, c.handle, odbc.SQL_ROLLBACK); isError(ret) {
			noteErr()
		}
		if ret := odbc.SQLSetConnectAttr(c.handle, odbc.SQL_ATTR_AUTOCOMMIT, odbc.SQLPOINTER(odbc.SQL_AUTOCOMMIT_ON), 0, nil); isError(ret) {
			noteErr()
		}
		c.isTransactionActive = false
	}

	// Disconnect connection
	if ret := odbc.SQLDisconnect(c.handle); isError(ret) {
		noteErr()
	}

	// Deallocate connection
	if ret := odbc.SQLFreeHandle(odbc.SQL_HANDLE_DBC, c.handle); isError(ret) {
		noteErr()
	}

	// Mark the connection closed and drop the now-unneeded finalizer.
	c.handle = 0
	c.isClosed = true
	runtime.SetFinalizer(c, nil)

	return firstErr
}
Example #5
0
// bindDateTime binds value as an SQL_TIMESTAMP parameter (C type
// SQL_C_TIMESTAMP) at the given parameter index. Only year through second are
// copied into the timestamp struct; no fractional-second field is set, and the
// binding uses 0 decimal digits.
func (stmt *statement) bindDateTime(index int, value time.Time, direction ParameterDirection) error {
	year, month, day := value.Date()
	hour, minute, second := value.Clock()

	// The struct is kept in stmt.bindValues so the buffer handed to the driver
	// stays referenced after this function returns.
	ts := &odbc.SQL_TIMESTAMP_STRUCT{
		Year:   odbc.SQLSMALLINT(year),
		Month:  odbc.SQLUSMALLINT(month),
		Day:    odbc.SQLUSMALLINT(day),
		Hour:   odbc.SQLUSMALLINT(hour),
		Minute: odbc.SQLUSMALLINT(minute),
		Second: odbc.SQLUSMALLINT(second),
	}
	stmt.bindValues[index] = ts

	// Column size 23 and buffer length 16 are fixed literals
	// (16 presumably being the size of SQL_TIMESTAMP_STRUCT — confirm).
	ret := odbc.SQLBindParameter(stmt.handle, odbc.SQLUSMALLINT(index), direction.SQLBindParameterType(), odbc.SQL_C_TIMESTAMP, odbc.SQL_TIMESTAMP, 23, 0, odbc.SQLPOINTER(unsafe.Pointer(ts)), 16, nil)
	if !isError(ret) {
		return nil
	}
	return errorStatement(stmt.handle, fmt.Sprintf("Bind index: %v, Value: %v", index, *ts))
}
Example #6
0
// bindNumeric binds value as an SQL_DOUBLE parameter (C type SQL_C_DOUBLE) at
// the given parameter index.
//
// precision and scale are currently unused: the value is sent as a double, not
// as an exact decimal. Binding a true decimal would require converting to
// SQL_NUMERIC_STRUCT (see http://support.microsoft.com/kb/181254), roughly:
//
//	SQLBindParameter(..., SQL_C_NUMERIC, SQL_DECIMAL, precision, scale, &numericStruct, 0, nil)
//	SQLSetDescField(stmtDescHandle, index, SQL_DESC_TYPE, SQL_NUMERIC, 0)
//	SQLSetDescField(stmtDescHandle, index, SQL_DESC_PRECISION, precision, 0)
//	SQLSetDescField(stmtDescHandle, index, SQL_DESC_SCALE, scale, 0)
func (stmt *statement) bindNumeric(index int, value float64, precision int, scale int, direction ParameterDirection) error {
	// Keep the bound buffer referenced from stmt.bindValues after returning.
	buf := &value
	stmt.bindValues[index] = buf

	ret := odbc.SQLBindParameter(stmt.handle, odbc.SQLUSMALLINT(index), direction.SQLBindParameterType(), odbc.SQL_C_DOUBLE, odbc.SQL_DOUBLE, 0, 0, odbc.SQLPOINTER(unsafe.Pointer(buf)), 0, nil)
	if !isError(ret) {
		return nil
	}
	return errorStatement(stmt.handle, fmt.Sprintf("Bind index: %v, Value: %v", index, value))
}
Example #7
0
// bindBool binds value as an SQL_BIT parameter (C type SQL_C_BIT) at the given
// parameter index.
func (stmt *statement) bindBool(index int, value bool, direction ParameterDirection) error {
	// Keep the bound buffer referenced from stmt.bindValues after returning.
	flag := &value
	stmt.bindValues[index] = flag

	ret := odbc.SQLBindParameter(stmt.handle, odbc.SQLUSMALLINT(index), direction.SQLBindParameterType(), odbc.SQL_C_BIT, odbc.SQL_BIT, 0, 0, odbc.SQLPOINTER(unsafe.Pointer(flag)), 0, nil)
	if !isError(ret) {
		return nil
	}
	return errorStatement(stmt.handle, fmt.Sprintf("Bind index: %v, Value: %v", index, value))
}
Example #8
0
// bindByteArray binds value as a binary parameter (C type SQL_C_BINARY) at the
// given parameter index. Values longer than 4000 bytes are bound as
// SQL_LONGVARBINARY, all others as SQL_VARBINARY.
//
// The byte count is passed through the length/indicator pointer (the last
// SQLBindParameter argument); without it the driver would treat the buffer as
// a NUL-terminated string.
//
// The buffer and its length are stored in stmt.bindValues, like every other
// bind helper does. ODBC reads bound parameter buffers when the statement
// executes, not during SQLBindParameter, so they must stay reachable after this
// function returns.
func (stmt *statement) bindByteArray(index int, value []byte, direction ParameterDirection) error {
	// length is declared as odbc.SQLLEN so its address can be passed directly
	// as the length/indicator pointer, without an unsafe cast from int.
	bindVal := &struct {
		value  []byte
		length odbc.SQLLEN
	}{
		value:  value,
		length: odbc.SQLLEN(len(value)),
	}

	sqlType := odbc.SQL_VARBINARY
	if bindVal.length > 4000 {
		sqlType = odbc.SQL_LONGVARBINARY
	}

	// &bindVal.value[0] would panic on an empty slice, and NULL cannot be
	// passed as the data pointer, so point at a one-byte placeholder. The
	// length indicator is still 0, so a zero-length value is written.
	if bindVal.length == 0 {
		bindVal.value = []byte{'\x00'}
	}

	// Keep the buffer and its length indicator reachable until execution.
	stmt.bindValues[index] = bindVal

	ret := odbc.SQLBindParameter(stmt.handle, odbc.SQLUSMALLINT(index), direction.SQLBindParameterType(), odbc.SQL_C_BINARY, sqlType, odbc.SQLULEN(bindVal.length), 0, odbc.SQLPOINTER(unsafe.Pointer(&bindVal.value[0])), 0, &bindVal.length)
	if isError(ret) {
		return errorStatement(stmt.handle, fmt.Sprintf("Bind index: %v, Value: %v", index, value))
	}
	return nil
}
Example #9
0
// bindString binds value as a character parameter at the given parameter
// index. The value is sent as NUL-terminated UTF-16 (C type SQL_C_WCHAR).
//
// length is the column size to declare. When it is 0, len(value) is used; that
// is the UTF-8 byte count, which is never smaller than the UTF-16 code-unit
// count. Column sizes below 4000 are bound as SQL_VARCHAR, larger ones as
// SQL_LONGVARCHAR.
//
// A value containing a NUL byte cannot be represented as a NUL-terminated
// string. It is rejected with an error instead of panicking, which is what the
// deprecated syscall.StringToUTF16 does.
func (stmt *statement) bindString(index int, value string, length int, direction ParameterDirection) error {
	if length == 0 {
		length = len(value)
	}

	// UTF16FromString appends the terminating NUL. The driver relies on it
	// because no length/indicator pointer is passed to SQLBindParameter below.
	utf16Val, err := syscall.UTF16FromString(value)
	if err != nil {
		return fmt.Errorf("Bind index: %v, Value: %q: %v", index, value, err)
	}
	// Keep the buffer referenced from stmt.bindValues until execution.
	stmt.bindValues[index] = utf16Val

	var sqlType odbc.SQLDataType
	if length < 4000 {
		sqlType = odbc.SQL_VARCHAR
	} else {
		sqlType = odbc.SQL_LONGVARCHAR
	}

	ret := odbc.SQLBindParameter(stmt.handle, odbc.SQLUSMALLINT(index), direction.SQLBindParameterType(), odbc.SQL_C_WCHAR, sqlType, odbc.SQLULEN(length), 0, odbc.SQLPOINTER(unsafe.Pointer(&utf16Val[0])), 0, nil)
	if isError(ret) {
		return errorStatement(stmt.handle, fmt.Sprintf("Bind index: %v, Value: %v", index, value))
	}
	return nil
}