// Build metadata for each result column func buildResultColumnDefinitions(stmtHandle odbc.SQLHandle, sqlStmt string) ([]resultColumnDef, odbc.SQLReturn) { //Get number of result columns var numColumns odbc.SQLSMALLINT ret := odbc.SQLNumResultCols(stmtHandle, &numColumns) if isError(ret) { errorStatement(stmtHandle, sqlStmt) } resultColumnDefs := make([]resultColumnDef, 0, numColumns) for colNum, lNumColumns := odbc.SQLSMALLINT(1), numColumns; colNum <= lNumColumns; colNum++ { //Get odbc.SQL type var sqlType odbc.SQLLEN ret := odbc.SQLColAttribute(stmtHandle, odbc.SQLUSMALLINT(colNum), odbc.SQL_COLUMN_TYPE, 0, 0, nil, &sqlType) if isError(ret) { errorStatement(stmtHandle, sqlStmt) } /* Disabled because it is no longer needed //Get length var length odbc.SQLLEN ret = odbc.SQLColAttribute(stmtHandle, odbc.SQLUSMALLINT(colNum), odbc.SQL_COLUMN_LENGTH, 0, 0, nil, &length) if isError(ret) { errorStatement(stmtHandle, sqlStmt) } //If the type is a CHAR or VARCHAR, add 4 to the length if odbc.SQLDataType(sqlType) == odbc.SQL_CHAR || odbc.SQLDataType(sqlType) == odbc.SQL_VARCHAR || odbc.SQLDataType(sqlType) == odbc.SQL_WCHAR || odbc.SQLDataType(sqlType) == odbc.SQL_WVARCHAR { length = length + 4 } */ //Get name const namelength = 1000 nameArr := make([]uint16, namelength) ret = odbc.SQLColAttribute(stmtHandle, odbc.SQLUSMALLINT(colNum), odbc.SQL_DESC_LABEL, uintptr(unsafe.Pointer(&nameArr[0])), namelength, nil, nil) if isError(ret) { errorStatement(stmtHandle, sqlStmt) } name := syscall.UTF16ToString(nameArr) //For numeric and decimal types, get the precision var precision odbc.SQLLEN if odbc.SQLDataType(sqlType) == odbc.SQL_NUMERIC || odbc.SQLDataType(sqlType) == odbc.SQL_DECIMAL { ret = odbc.SQLColAttribute(stmtHandle, odbc.SQLUSMALLINT(colNum), odbc.SQL_COLUMN_PRECISION, 0, 0, nil, &precision) if isError(ret) { errorStatement(stmtHandle, sqlStmt) } } //For numeric and decimal types, get the scale var scale odbc.SQLLEN if odbc.SQLDataType(sqlType) == odbc.SQL_NUMERIC || odbc.SQLDataType(sqlType) == odbc.SQL_DECIMAL { ret = odbc.SQLColAttribute(stmtHandle, odbc.SQLUSMALLINT(colNum), odbc.SQL_COLUMN_SCALE, 0, 0, nil, &scale) if isError(ret) { errorStatement(stmtHandle, sqlStmt) } } col := resultColumnDef{RecNum: colNum, DataType: odbc.SQLDataType(sqlType), Name: name, Precision: precision, Scale: scale} resultColumnDefs = append(resultColumnDefs, col) } return resultColumnDefs, odbc.SQL_SUCCESS }
func (stmt *statement) Query(args []driver.Value) (driver.Rows, error) { //Clear any existing bind values stmt.bindValues = make([]interface{}, len(args)+1) //Bind the parameters bindParameters, err := stmt.convertToBindParameters(args) if err != nil { return nil, err } stmt.bindParameters(bindParameters) //If rows is not nil, close rows and set to nil if stmt.rows != nil { stmt.rows.Close() stmt.rows = nil } //Execute SQL statement sqlStmtSqlPtr := (*odbc.SQLCHAR)(unsafe.Pointer(syscall.StringToUTF16Ptr(stmt.sqlStmt))) ret := odbc.SQLExecDirect(stmt.handle, sqlStmtSqlPtr, odbc.SQL_NTS) if isError(ret) { return nil, errorStatement(stmt.handle, fmt.Sprintf("SQL Stmt: %v\nBind Values: %v", stmt.sqlStmt, stmt.formatBindValues())) } //Get row descriptor handle var descRowHandle odbc.SQLHandle ret = odbc.SQLGetStmtAttr(stmt.handle, odbc.SQL_ATTR_APP_ROW_DESC, uintptr(unsafe.Pointer(&descRowHandle)), 0, nil) if isError(ret) { return nil, errorStatement(stmt.handle, fmt.Sprintf("SQL Stmt: %v\nBind Values: %v", stmt.sqlStmt, stmt.formatBindValues())) } //Check to see if the query option ResultSetNum was passed and if so, iterate through result sets optionValue, optionFound := getOptionValue(stmt.queryOptions, ResultSetNum) if optionFound { for counter, resultSetNum := 0, int(optionValue.(float64)); counter < resultSetNum; counter++ { ret := odbc.SQLMoreResults(stmt.handle) if isError(ret) { return nil, errorStatement(stmt.handle, fmt.Sprintf("SQL Stmt: %v", stmt.sqlStmt)) } } } else { //If query option ResultSetNum was not passed, iterate through result sets until at least one column is found for { var numColumns odbc.SQLSMALLINT ret := odbc.SQLNumResultCols(stmt.handle, &numColumns) if isError(ret) { return nil, errorStatement(stmt.handle, fmt.Sprintf("SQL Stmt: %v", stmt.sqlStmt)) } if numColumns > 0 { break } else { ret := odbc.SQLMoreResults(stmt.handle) if isError(ret) { return nil, errorStatement(stmt.handle, fmt.Sprintf("SQL Stmt: %v", stmt.sqlStmt)) } } } } //Get definition of result columns resultColumnDefs, ret := buildResultColumnDefinitions(stmt.handle, stmt.sqlStmt) if isError(ret) { return nil, errorStatement(stmt.handle, fmt.Sprintf("SQL Stmt: %v\nBind Values: %v", stmt.sqlStmt, stmt.formatBindValues())) } //Make a slice of the column names columnNames := make([]string, len(resultColumnDefs)) for index, resultCol := range resultColumnDefs { columnNames[index] = fmt.Sprint(resultCol.Name) } //Create rows stmt.rows = &rows{handle: stmt.handle, descHandle: descRowHandle, isBeforeFirst: true, resultColumnDefs: resultColumnDefs, resultColumnNames: columnNames, sqlStmt: stmt.sqlStmt} //Add a finalizer runtime.SetFinalizer(stmt.rows, (*rows).Close) return stmt.rows, nil }