111 changes: 79 additions & 32 deletions mssql_python/pybind/ddbc_bindings.cpp
@@ -31,6 +31,7 @@
#define ARCHITECTURE "win64" // Default to win64 if not defined during compilation
#endif
#define DAE_CHUNK_SIZE 8192
+#define SQL_MAX_LOB_SIZE 8000
//-------------------------------------------------------------------------------------------------
// Class definitions
//-------------------------------------------------------------------------------------------------
@@ -1747,8 +1748,13 @@ static py::object FetchLobColumnData(SQLHSTMT hStmt,
&actualRead);

if (ret == SQL_ERROR || !SQL_SUCCEEDED(ret) && ret != SQL_SUCCESS_WITH_INFO) {
LOG("Loop {}: Error fetching column {} with cType={}", loopCount, colIndex, cType);
ThrowStdException("Error fetching column data");
std::ostringstream oss;
oss << "Error fetching LOB for column " << colIndex
<< ", cType=" << cType
<< ", loop=" << loopCount
<< ", SQLGetData return=" << ret;
LOG(oss.str());
ThrowStdException(oss.str());
}
if (actualRead == SQL_NULL_DATA) {
LOG("Loop {}: Column {} is NULL", loopCount, colIndex);
@@ -1862,7 +1868,7 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
case SQL_CHAR:
case SQL_VARCHAR:
case SQL_LONGVARCHAR: {
-if (columnSize == SQL_NO_TOTAL || columnSize == 0 || columnSize > 8000) {
+if (columnSize == SQL_NO_TOTAL || columnSize == 0 || columnSize > SQL_MAX_LOB_SIZE) {
LOG("Streaming LOB for column {}", i);
row.append(FetchLobColumnData(hStmt, i, SQL_C_CHAR, false, false));
} else {
@@ -2406,7 +2412,7 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column
// Fetch rows in batches
// TODO: Move to anonymous namespace, since it is not used outside this file
SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& columnNames,
py::list& rows, SQLUSMALLINT numCols, SQLULEN& numRowsFetched) {
py::list& rows, SQLUSMALLINT numCols, SQLULEN& numRowsFetched, const std::vector<SQLUSMALLINT>& lobColumns) {
LOG("Fetching data in batches");
SQLRETURN ret = SQLFetchScroll_ptr(hStmt, SQL_FETCH_NEXT, 0);
if (ret == SQL_NO_DATA) {
@@ -2466,25 +2472,19 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
case SQL_CHAR:
case SQL_VARCHAR:
case SQL_LONGVARCHAR: {
-// TODO: variable length data needs special handling, this logic wont suffice
SQLULEN columnSize = columnMeta["ColumnSize"].cast<SQLULEN>();
HandleZeroColumnSizeAtFetch(columnSize);
uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/;
uint64_t numCharsInData = dataLen / sizeof(SQLCHAR);
+bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end();
// fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<'
-if (numCharsInData < fetchBufferSize) {
+if (!isLob && numCharsInData < fetchBufferSize) {
// SQLFetch will nullterminate the data
row.append(std::string(
reinterpret_cast<char*>(&buffers.charBuffers[col - 1][i * fetchBufferSize]),
numCharsInData));
} else {
-// In this case, buffer size is smaller, and data to be retrieved is longer
-// TODO: Revisit
-std::ostringstream oss;
-oss << "Buffer length for fetch (" << columnSize << ") is smaller, & data "
-<< "to be retrieved is longer (" << numCharsInData << "). ColumnID - "
-<< col << ", datatype - " << dataType;
-ThrowStdException(oss.str());
+row.append(FetchLobColumnData(hStmt, col, SQL_C_CHAR, false, false));
}
break;
}
@@ -2496,8 +2496,9 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
HandleZeroColumnSizeAtFetch(columnSize);
uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/;
uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR);
+bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end();
// fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<'
-if (numCharsInData < fetchBufferSize) {
+if (!isLob && numCharsInData < fetchBufferSize) {
// SQLFetch will nullterminate the data
#if defined(__APPLE__) || defined(__linux__)
// Use unix-specific conversion to handle the wchar_t/SQLWCHAR size difference
@@ -2511,13 +2512,7 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
numCharsInData));
#endif
} else {
-// In this case, buffer size is smaller, and data to be retrieved is longer
-// TODO: Revisit
-std::ostringstream oss;
-oss << "Buffer length for fetch (" << columnSize << ") is smaller, & data "
-<< "to be retrieved is longer (" << numCharsInData << "). ColumnID - "
-<< col << ", datatype - " << dataType;
-ThrowStdException(oss.str());
+row.append(FetchLobColumnData(hStmt, col, SQL_C_WCHAR, true, false));
}
break;
}
@@ -2603,21 +2598,15 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
case SQL_BINARY:
case SQL_VARBINARY:
case SQL_LONGVARBINARY: {
-// TODO: variable length data needs special handling, this logic wont suffice
SQLULEN columnSize = columnMeta["ColumnSize"].cast<SQLULEN>();
HandleZeroColumnSizeAtFetch(columnSize);
-if (static_cast<size_t>(dataLen) <= columnSize) {
+bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end();
+if (!isLob && static_cast<size_t>(dataLen) <= columnSize) {
row.append(py::bytes(reinterpret_cast<const char*>(
&buffers.charBuffers[col - 1][i * columnSize]),
dataLen));
} else {
-// In this case, buffer size is smaller, and data to be retrieved is longer
-// TODO: Revisit
-std::ostringstream oss;
-oss << "Buffer length for fetch (" << columnSize << ") is smaller, & data "
-<< "to be retrieved is longer (" << dataLen << "). ColumnID - "
-<< col << ", datatype - " << dataType;
-ThrowStdException(oss.str());
+row.append(FetchLobColumnData(hStmt, col, SQL_C_BINARY, false, true));
}
break;
}
@@ -2746,6 +2735,35 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch
return ret;
}

+std::vector<SQLUSMALLINT> lobColumns;
+for (SQLSMALLINT i = 0; i < numCols; i++) {
+auto colMeta = columnNames[i].cast<py::dict>();
+SQLSMALLINT dataType = colMeta["DataType"].cast<SQLSMALLINT>();
+SQLULEN columnSize = colMeta["ColumnSize"].cast<SQLULEN>();
+
+if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR ||
+dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR ||
+dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY) &&
+(columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) {
+lobColumns.push_back(i + 1); // 1-based
+}
+}
+
+// If we have LOBs → fall back to row-by-row fetch + SQLGetData_wrap
+if (!lobColumns.empty()) {
+LOG("LOB columns detected → using per-row SQLGetData path");
+while (true) {
+ret = SQLFetch_ptr(hStmt);
+if (ret == SQL_NO_DATA) break;
+if (!SQL_SUCCEEDED(ret)) return ret;
+
+py::list row;
+SQLGetData_wrap(StatementHandle, numCols, row); // <-- streams LOBs correctly
+rows.append(row);
+}
+return SQL_SUCCESS;
+}
+
// Initialize column buffers
ColumnBuffers buffers(numCols, fetchSize);

@@ -2760,7 +2778,7 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch
SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0);
SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0);

-ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched);
+ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched, lobColumns);
if (!SQL_SUCCEEDED(ret) && ret != SQL_NO_DATA) {
LOG("Error when fetching data");
return ret;
@@ -2839,6 +2857,35 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows) {
}
LOG("Fetching data in batch sizes of {}", fetchSize);

+std::vector<SQLUSMALLINT> lobColumns;
+for (SQLSMALLINT i = 0; i < numCols; i++) {
+auto colMeta = columnNames[i].cast<py::dict>();
+SQLSMALLINT dataType = colMeta["DataType"].cast<SQLSMALLINT>();
+SQLULEN columnSize = colMeta["ColumnSize"].cast<SQLULEN>();
+
+if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR ||
+dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR ||
+dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY) &&
+(columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) {
+lobColumns.push_back(i + 1); // 1-based
+}
+}
+
+// If we have LOBs → fall back to row-by-row fetch + SQLGetData_wrap
+if (!lobColumns.empty()) {
+LOG("LOB columns detected → using per-row SQLGetData path");
+while (true) {
+ret = SQLFetch_ptr(hStmt);
+if (ret == SQL_NO_DATA) break;
+if (!SQL_SUCCEEDED(ret)) return ret;
+
+py::list row;
+SQLGetData_wrap(StatementHandle, numCols, row); // <-- streams LOBs correctly
+rows.append(row);
+}
+return SQL_SUCCESS;
+}
+
ColumnBuffers buffers(numCols, fetchSize);

// Bind columns
@@ -2853,7 +2900,7 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows) {
SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0);

while (ret != SQL_NO_DATA) {
-ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched);
+ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched, lobColumns);
if (!SQL_SUCCEEDED(ret) && ret != SQL_NO_DATA) {
LOG("Error when fetching data");
return ret;