Ejemplo n.º 1
0
// buildResultColumnDefinitions builds the metadata for each column of the
// result set currently available on stmtHandle.
//
// For every column (numbered from 1) it reads the SQL data type and the column
// label. Precision and scale are only queried for NUMERIC and DECIMAL columns;
// for every other type they are left at zero.
//
// On success it returns the column definitions and odbc.SQL_SUCCESS. If any
// ODBC call fails it stops immediately and returns a nil slice together with
// the failing return code, without issuing further ODBC calls on the handle,
// so the caller can check the code with isError and build the error from the
// statement handle itself.
//
// sqlStmt is not used by this function; it is kept so the signature stays
// compatible with existing callers.
func buildResultColumnDefinitions(stmtHandle odbc.SQLHandle, sqlStmt string) ([]resultColumnDef, odbc.SQLReturn) {
	// Get number of result columns.
	var numColumns odbc.SQLSMALLINT
	if ret := odbc.SQLNumResultCols(stmtHandle, &numColumns); isError(ret) {
		return nil, ret
	}

	// Number of UTF-16 code units allocated for a column label; the same value
	// is passed to the driver as the buffer length.
	const nameLength = 1000

	resultColumnDefs := make([]resultColumnDef, 0, numColumns)
	for colNum := odbc.SQLSMALLINT(1); colNum <= numColumns; colNum++ {
		column := odbc.SQLUSMALLINT(colNum)

		// Get the SQL type.
		var sqlType odbc.SQLLEN
		if ret := odbc.SQLColAttribute(stmtHandle, column, odbc.SQL_COLUMN_TYPE, 0, 0, nil, &sqlType); isError(ret) {
			return nil, ret
		}
		dataType := odbc.SQLDataType(sqlType)

		// The column length is intentionally not fetched: it is no longer needed.

		// Get the label. A fresh zeroed buffer is used per column so that the
		// decoded string can never contain data left over from a previous column.
		nameArr := make([]uint16, nameLength)
		if ret := odbc.SQLColAttribute(stmtHandle, column, odbc.SQL_DESC_LABEL, uintptr(unsafe.Pointer(&nameArr[0])), nameLength, nil, nil); isError(ret) {
			return nil, ret
		}
		name := syscall.UTF16ToString(nameArr)

		// For numeric and decimal types, get the precision and the scale.
		var precision, scale odbc.SQLLEN
		if dataType == odbc.SQL_NUMERIC || dataType == odbc.SQL_DECIMAL {
			if ret := odbc.SQLColAttribute(stmtHandle, column, odbc.SQL_COLUMN_PRECISION, 0, 0, nil, &precision); isError(ret) {
				return nil, ret
			}
			if ret := odbc.SQLColAttribute(stmtHandle, column, odbc.SQL_COLUMN_SCALE, 0, 0, nil, &scale); isError(ret) {
				return nil, ret
			}
		}

		resultColumnDefs = append(resultColumnDefs, resultColumnDef{
			RecNum:    colNum,
			DataType:  dataType,
			Name:      name,
			Precision: precision,
			Scale:     scale,
		})
	}

	return resultColumnDefs, odbc.SQL_SUCCESS
}
Ejemplo n.º 2
0
// Query binds args to the statement, executes it and returns the rows of the
// selected result set.
//
// Steps, in order:
//   - reset the stored bind values and bind the converted parameters;
//   - close any rows left over from a previous execution;
//   - execute the SQL text with SQLExecDirect;
//   - position on a result set: if the ResultSetNum query option is present,
//     advance that many result sets, otherwise advance until a result set
//     with at least one column is found;
//   - read the column metadata and wrap everything in a rows value.
//
// It returns an error if parameter conversion fails, if the SQL text cannot be
// converted to UTF-16 (it contains a NUL byte), if the ResultSetNum option does
// not hold a float64, or if any ODBC call reports an error.
func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) {
	// Clear any existing bind values.
	stmt.bindValues = make([]interface{}, len(args)+1)

	// Bind the parameters.
	bindParameters, err := stmt.convertToBindParameters(args)
	if err != nil {
		return nil, err
	}
	stmt.bindParameters(bindParameters)

	// If rows is not nil, close rows and set to nil.
	if stmt.rows != nil {
		stmt.rows.Close()
		stmt.rows = nil
	}

	// Execute SQL statement. UTF16PtrFromString reports an embedded NUL as an
	// error, whereas the deprecated StringToUTF16Ptr would panic.
	sqlStmtUTF16, err := syscall.UTF16PtrFromString(stmt.sqlStmt)
	if err != nil {
		return nil, fmt.Errorf("converting SQL statement to UTF-16: %v", err)
	}
	sqlStmtSqlPtr := (*odbc.SQLCHAR)(unsafe.Pointer(sqlStmtUTF16))
	ret := odbc.SQLExecDirect(stmt.handle, sqlStmtSqlPtr, odbc.SQL_NTS)
	if isError(ret) {
		return nil, errorStatement(stmt.handle, fmt.Sprintf("SQL Stmt: %v\nBind Values: %v", stmt.sqlStmt, stmt.formatBindValues()))
	}

	// Get row descriptor handle.
	var descRowHandle odbc.SQLHandle
	ret = odbc.SQLGetStmtAttr(stmt.handle, odbc.SQL_ATTR_APP_ROW_DESC, uintptr(unsafe.Pointer(&descRowHandle)), 0, nil)
	if isError(ret) {
		return nil, errorStatement(stmt.handle, fmt.Sprintf("SQL Stmt: %v\nBind Values: %v", stmt.sqlStmt, stmt.formatBindValues()))
	}

	// Check to see if the query option ResultSetNum was passed and if so,
	// iterate through result sets.
	optionValue, optionFound := getOptionValue(stmt.queryOptions, ResultSetNum)
	if optionFound {
		// Checked assertion: a value of any other type is reported as an error
		// instead of panicking.
		resultSetNum, ok := optionValue.(float64)
		if !ok {
			return nil, fmt.Errorf("query option ResultSetNum must be a float64, got %T", optionValue)
		}
		for counter, target := 0, int(resultSetNum); counter < target; counter++ {
			if ret := odbc.SQLMoreResults(stmt.handle); isError(ret) {
				return nil, errorStatement(stmt.handle, fmt.Sprintf("SQL Stmt: %v", stmt.sqlStmt))
			}
		}
	} else {
		// If query option ResultSetNum was not passed, iterate through result
		// sets until at least one column is found.
		// NOTE(review): this loop only ends when a result set with columns is
		// found or isError reports a failure. Confirm that isError treats the
		// "no more results" code from SQLMoreResults as an error; otherwise a
		// statement that produces no result set would loop here forever.
		for {
			var numColumns odbc.SQLSMALLINT
			if ret := odbc.SQLNumResultCols(stmt.handle, &numColumns); isError(ret) {
				return nil, errorStatement(stmt.handle, fmt.Sprintf("SQL Stmt: %v", stmt.sqlStmt))
			}
			if numColumns > 0 {
				break
			}
			if ret := odbc.SQLMoreResults(stmt.handle); isError(ret) {
				return nil, errorStatement(stmt.handle, fmt.Sprintf("SQL Stmt: %v", stmt.sqlStmt))
			}
		}
	}

	// Get definition of result columns.
	resultColumnDefs, ret := buildResultColumnDefinitions(stmt.handle, stmt.sqlStmt)
	if isError(ret) {
		return nil, errorStatement(stmt.handle, fmt.Sprintf("SQL Stmt: %v\nBind Values: %v", stmt.sqlStmt, stmt.formatBindValues()))
	}

	// Make a slice of the column names.
	columnNames := make([]string, len(resultColumnDefs))
	for index, resultCol := range resultColumnDefs {
		columnNames[index] = fmt.Sprint(resultCol.Name)
	}

	// Create rows.
	stmt.rows = &rows{handle: stmt.handle, descHandle: descRowHandle, isBeforeFirst: true, resultColumnDefs: resultColumnDefs, resultColumnNames: columnNames, sqlStmt: stmt.sqlStmt}

	// Add a finalizer so Close is called on the rows if they become
	// unreachable without having been closed explicitly.
	runtime.SetFinalizer(stmt.rows, (*rows).Close)

	return stmt.rows, nil
}