Merged
208 changes: 125 additions & 83 deletions mssql_python/pybind/ddbc_bindings.cpp
@@ -31,6 +31,7 @@
#define ARCHITECTURE "win64" // Default to win64 if not defined during compilation
#endif
#define DAE_CHUNK_SIZE 8192
#define SQL_MAX_LOB_SIZE 8000
//-------------------------------------------------------------------------------------------------
// Class definitions
//-------------------------------------------------------------------------------------------------
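The new SQL_MAX_LOB_SIZE constant is the size threshold above which (or when the driver reports SQL_NO_TOTAL or 0) a variable-length column is treated as a LOB and streamed with SQLGetData instead of being read into a pre-sized buffer. A minimal sketch of that classification, mirroring the detection loops added to FetchMany_wrap and FetchAll_wrap further down (the helper name is illustrative, not part of the file):

    // Sketch: classify a column as a LOB based purely on its reported metadata.
    static bool IsLobColumn(SQLSMALLINT dataType, SQLULEN columnSize) {
        const bool variableLength =
            dataType == SQL_VARCHAR   || dataType == SQL_LONGVARCHAR  ||
            dataType == SQL_WVARCHAR  || dataType == SQL_WLONGVARCHAR ||
            dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY;
        // SQL_NO_TOTAL or 0 means the driver could not report a definite size up front.
        return variableLength &&
               (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE);
    }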
@@ -1747,8 +1748,13 @@ static py::object FetchLobColumnData(SQLHSTMT hStmt,
&actualRead);

if (ret == SQL_ERROR || !SQL_SUCCEEDED(ret) && ret != SQL_SUCCESS_WITH_INFO) {
LOG("Loop {}: Error fetching column {} with cType={}", loopCount, colIndex, cType);
ThrowStdException("Error fetching column data");
std::ostringstream oss;
oss << "Error fetching LOB for column " << colIndex
<< ", cType=" << cType
<< ", loop=" << loopCount
<< ", SQLGetData return=" << ret;
LOG(oss.str());
ThrowStdException(oss.str());
}
if (actualRead == SQL_NULL_DATA) {
LOG("Loop {}: Column {} is NULL", loopCount, colIndex);
@@ -1862,7 +1868,7 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
case SQL_CHAR:
case SQL_VARCHAR:
case SQL_LONGVARCHAR: {
if (columnSize == SQL_NO_TOTAL || columnSize == 0 || columnSize > 8000) {
if (columnSize == SQL_NO_TOTAL || columnSize == 0 || columnSize > SQL_MAX_LOB_SIZE) {
LOG("Streaming LOB for column {}", i);
row.append(FetchLobColumnData(hStmt, i, SQL_C_CHAR, false, false));
} else {
@@ -1884,6 +1890,10 @@
#else
row.append(std::string(reinterpret_cast<char*>(dataBuffer.data())));
#endif
} else {
// Buffer too small, fallback to streaming
LOG("CHAR column {} data truncated, using streaming LOB", i);
row.append(FetchLobColumnData(hStmt, i, SQL_C_CHAR, false, false));
}
} else if (dataLen == SQL_NULL_DATA) {
LOG("Column {} is NULL (CHAR)", i);
@@ -1911,62 +1921,53 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p
}
case SQL_WCHAR:
case SQL_WVARCHAR:
case SQL_WLONGVARCHAR: {
// TODO: revisit
HandleZeroColumnSizeAtFetch(columnSize);
uint64_t fetchBufferSize = columnSize + 1 /* null-termination */;
std::vector<SQLWCHAR> dataBuffer(fetchBufferSize);
SQLLEN dataLen;
ret = SQLGetData_ptr(hStmt, i, SQL_C_WCHAR, dataBuffer.data(),
dataBuffer.size() * sizeof(SQLWCHAR), &dataLen);

if (SQL_SUCCEEDED(ret)) {
// TODO: Refactor these if's across other switches to avoid code duplication
if (dataLen > 0) {
uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR);
if (numCharsInData < dataBuffer.size()) {
// SQLGetData will null-terminate the data
case SQL_WLONGVARCHAR: {
if (columnSize == SQL_NO_TOTAL || columnSize == 0 || columnSize > 4000) {
LOG("Streaming LOB for column {} (NVARCHAR)", i);
row.append(FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false));
} else {
uint64_t fetchBufferSize = (columnSize + 1) * sizeof(SQLWCHAR); // +1 for null terminator
std::vector<SQLWCHAR> dataBuffer(columnSize + 1);
SQLLEN dataLen;
ret = SQLGetData_ptr(hStmt, i, SQL_C_WCHAR, dataBuffer.data(), fetchBufferSize, &dataLen);
if (SQL_SUCCEEDED(ret)) {
if (dataLen > 0) {
uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR);
if (numCharsInData < dataBuffer.size()) {
#if defined(__APPLE__) || defined(__linux__)
auto raw_bytes = reinterpret_cast<const char*>(dataBuffer.data());
size_t actualBufferSize = dataBuffer.size() * sizeof(SQLWCHAR);
if (dataLen < 0 || static_cast<size_t>(dataLen) > actualBufferSize) {
LOG("Error: py::bytes creation request exceeds buffer size. dataLen={} buffer={}",
dataLen, actualBufferSize);
ThrowStdException("Invalid buffer length for py::bytes");
}
py::bytes py_bytes(raw_bytes, dataLen);
py::str decoded = py_bytes.attr("decode")("utf-16-le");
row.append(decoded);
const SQLWCHAR* sqlwBuf = reinterpret_cast<const SQLWCHAR*>(dataBuffer.data());
std::wstring wstr = SQLWCHARToWString(sqlwBuf, numCharsInData);
std::string utf8str = WideToUTF8(wstr);
row.append(py::str(utf8str));
#else
row.append(std::wstring(dataBuffer.data()));
std::wstring wstr(reinterpret_cast<wchar_t*>(dataBuffer.data()));
row.append(py::cast(wstr));
#endif
} else {
// In this case, buffer size is smaller, and data to be retrieved is longer
// TODO: Revisit
std::ostringstream oss;
oss << "Buffer length for fetch (" << dataBuffer.size()-1 << ") is smaller, & data "
<< "to be retrieved is longer (" << numCharsInData << "). ColumnID - "
<< i << ", datatype - " << dataType;
ThrowStdException(oss.str());
LOG("Appended NVARCHAR string of length {} to result row", numCharsInData);
} else {
// Buffer too small, fallback to streaming
LOG("NVARCHAR column {} data truncated, using streaming LOB", i);
row.append(FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false));
}
} else if (dataLen == SQL_NULL_DATA) {
LOG("Column {} is NULL (CHAR)", i);
row.append(py::none());
} else if (dataLen == 0) {
row.append(py::str(""));
} else if (dataLen == SQL_NO_TOTAL) {
LOG("SQLGetData couldn't determine the length of the NVARCHAR data. Returning NULL. Column ID - {}", i);
row.append(py::none());
} else if (dataLen < 0) {
LOG("SQLGetData returned an unexpected negative data length. "
"Raising exception. Column ID - {}, Data Type - {}, Data Length - {}",
i, dataType, dataLen);
ThrowStdException("SQLGetData returned an unexpected negative data length");
}
} else if (dataLen == SQL_NULL_DATA) {
row.append(py::none());
} else if (dataLen == 0) {
// Handle zero-length (non-NULL) data
row.append(py::str(""));
} else if (dataLen < 0) {
// This is unexpected
LOG("SQLGetData returned an unexpected negative data length. "
"Raising exception. Column ID - {}, Data Type - {}, Data Length - {}",
i, dataType, dataLen);
ThrowStdException("SQLGetData returned an unexpected negative data length");
} else {
LOG("Error retrieving data for column {} (NVARCHAR), SQLGetData return code {}", i, ret);
row.append(py::none());
}
} else {
LOG("Error retrieving data for column - {}, data type - {}, SQLGetData return "
"code - {}. Returning NULL value instead",
i, dataType, ret);
row.append(py::none());
}
}
break;
}
case SQL_INTEGER: {
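A note on the platform split in the SQL_WCHAR branch above: on macOS and Linux, SQLWCHAR is a 2-byte UTF-16 code unit while wchar_t is 4 bytes, so the fetched buffer cannot be handed to std::wstring directly and is converted to UTF-8 instead (the diff uses the file's SQLWCHARToWString and WideToUTF8 helpers). A self-contained sketch of that UTF-16LE to UTF-8 conversion, with an illustrative function name:

    // Sketch: convert a SQLWCHAR (UTF-16) buffer of `len` code units to UTF-8.
    // Lone surrogates are passed through unpaired for simplicity.
    static std::string Utf16ToUtf8(const SQLWCHAR* src, size_t len) {
        std::string out;
        out.reserve(len * 3);
        for (size_t i = 0; i < len; ++i) {
            uint32_t cp = src[i];
            if (cp >= 0xD800 && cp <= 0xDBFF && i + 1 < len &&
                src[i + 1] >= 0xDC00 && src[i + 1] <= 0xDFFF) {
                cp = 0x10000 + ((cp - 0xD800) << 10) + (src[i + 1] - 0xDC00);  // surrogate pair
                ++i;
            }
            if (cp < 0x80) {
                out.push_back(static_cast<char>(cp));
            } else if (cp < 0x800) {
                out.push_back(static_cast<char>(0xC0 | (cp >> 6)));
                out.push_back(static_cast<char>(0x80 | (cp & 0x3F)));
            } else if (cp < 0x10000) {
                out.push_back(static_cast<char>(0xE0 | (cp >> 12)));
                out.push_back(static_cast<char>(0x80 | ((cp >> 6) & 0x3F)));
                out.push_back(static_cast<char>(0x80 | (cp & 0x3F)));
            } else {
                out.push_back(static_cast<char>(0xF0 | (cp >> 18)));
                out.push_back(static_cast<char>(0x80 | ((cp >> 12) & 0x3F)));
                out.push_back(static_cast<char>(0x80 | ((cp >> 6) & 0x3F)));
                out.push_back(static_cast<char>(0x80 | (cp & 0x3F)));
            }
        }
        return out;
    }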
@@ -2411,7 +2412,7 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column
// Fetch rows in batches
// TODO: Move to anonymous namespace, since it is not used outside this file
SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& columnNames,
py::list& rows, SQLUSMALLINT numCols, SQLULEN& numRowsFetched) {
py::list& rows, SQLUSMALLINT numCols, SQLULEN& numRowsFetched, const std::vector<SQLUSMALLINT>& lobColumns) {
LOG("Fetching data in batches");
SQLRETURN ret = SQLFetchScroll_ptr(hStmt, SQL_FETCH_NEXT, 0);
if (ret == SQL_NO_DATA) {
@@ -2471,25 +2472,19 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
case SQL_CHAR:
case SQL_VARCHAR:
case SQL_LONGVARCHAR: {
// TODO: variable length data needs special handling, this logic wont suffice
SQLULEN columnSize = columnMeta["ColumnSize"].cast<SQLULEN>();
HandleZeroColumnSizeAtFetch(columnSize);
uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/;
uint64_t numCharsInData = dataLen / sizeof(SQLCHAR);
bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end();
// fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<'
if (numCharsInData < fetchBufferSize) {
if (!isLob && numCharsInData < fetchBufferSize) {
// SQLFetch will nullterminate the data
row.append(std::string(
reinterpret_cast<char*>(&buffers.charBuffers[col - 1][i * fetchBufferSize]),
numCharsInData));
} else {
// In this case, buffer size is smaller, and data to be retrieved is longer
// TODO: Revisit
std::ostringstream oss;
oss << "Buffer length for fetch (" << columnSize << ") is smaller, & data "
<< "to be retrieved is longer (" << numCharsInData << "). ColumnID - "
<< col << ", datatype - " << dataType;
ThrowStdException(oss.str());
row.append(FetchLobColumnData(hStmt, col, SQL_C_CHAR, false, false));
}
break;
}
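The i * fetchBufferSize indexing above relies on column-wise binding: each bound character column owns one contiguous slab holding fetchSize cells of columnSize + 1 bytes each. A minimal sketch of that assumed layout (illustrative only; the real ColumnBuffers class in this file may be organised differently):

    struct CharColumnSlab {
        std::vector<SQLCHAR> data;     // fetchSize * (columnSize + 1) bytes for the whole rowset
        SQLULEN cellWidth = 0;         // columnSize + 1, i.e. fetchBufferSize

        void allocate(SQLULEN columnSize, SQLULEN fetchSize) {
            cellWidth = columnSize + 1;            // +1 for the null terminator SQLFetch writes
            data.assign(fetchSize * cellWidth, 0);
        }
        SQLCHAR* cell(SQLULEN row) {               // start of row `row` within this column's slab
            return &data[row * cellWidth];
        }
    };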
@@ -2501,8 +2496,9 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
HandleZeroColumnSizeAtFetch(columnSize);
uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/;
uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR);
bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end();
// fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<'
if (numCharsInData < fetchBufferSize) {
if (!isLob && numCharsInData < fetchBufferSize) {
// SQLFetch will nullterminate the data
#if defined(__APPLE__) || defined(__linux__)
// Use unix-specific conversion to handle the wchar_t/SQLWCHAR size difference
Expand All @@ -2516,13 +2512,7 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
numCharsInData));
#endif
} else {
// In this case, buffer size is smaller, and data to be retrieved is longer
// TODO: Revisit
std::ostringstream oss;
oss << "Buffer length for fetch (" << columnSize << ") is smaller, & data "
<< "to be retrieved is longer (" << numCharsInData << "). ColumnID - "
<< col << ", datatype - " << dataType;
ThrowStdException(oss.str());
row.append(FetchLobColumnData(hStmt, col, SQL_C_WCHAR, true, false));
}
break;
}
@@ -2608,21 +2598,15 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
case SQL_BINARY:
case SQL_VARBINARY:
case SQL_LONGVARBINARY: {
// TODO: variable length data needs special handling, this logic wont suffice
SQLULEN columnSize = columnMeta["ColumnSize"].cast<SQLULEN>();
HandleZeroColumnSizeAtFetch(columnSize);
if (static_cast<size_t>(dataLen) <= columnSize) {
bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end();
if (!isLob && static_cast<size_t>(dataLen) <= columnSize) {
row.append(py::bytes(reinterpret_cast<const char*>(
&buffers.charBuffers[col - 1][i * columnSize]),
dataLen));
} else {
// In this case, buffer size is smaller, and data to be retrieved is longer
// TODO: Revisit
std::ostringstream oss;
oss << "Buffer length for fetch (" << columnSize << ") is smaller, & data "
<< "to be retrieved is longer (" << dataLen << "). ColumnID - "
<< col << ", datatype - " << dataType;
ThrowStdException(oss.str());
row.append(FetchLobColumnData(hStmt, col, SQL_C_BINARY, false, true));
}
break;
}
@@ -2751,6 +2735,35 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch
return ret;
}

std::vector<SQLUSMALLINT> lobColumns;
for (SQLSMALLINT i = 0; i < numCols; i++) {
auto colMeta = columnNames[i].cast<py::dict>();
SQLSMALLINT dataType = colMeta["DataType"].cast<SQLSMALLINT>();
SQLULEN columnSize = colMeta["ColumnSize"].cast<SQLULEN>();

if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR ||
dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR ||
dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY) &&
(columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) {
lobColumns.push_back(i + 1); // 1-based
}
}

// If we have LOBs → fall back to row-by-row fetch + SQLGetData_wrap
if (!lobColumns.empty()) {
LOG("LOB columns detected → using per-row SQLGetData path");
while (true) {
ret = SQLFetch_ptr(hStmt);
if (ret == SQL_NO_DATA) break;
if (!SQL_SUCCEEDED(ret)) return ret;

py::list row;
SQLGetData_wrap(StatementHandle, numCols, row); // <-- streams LOBs correctly
rows.append(row);
}
return SQL_SUCCESS;
}
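One detail worth noting about this fallback: SQLGetData generally requires a rowset size of 1 unless the driver advertises SQL_GD_BLOCK, which is why the per-row path runs before SQL_ATTR_ROW_ARRAY_SIZE is raised to fetchSize below (the ODBC default of 1 still applies here, assuming nothing earlier changed it). A defensive variant could pin it explicitly:

    // Hedged sketch: explicitly reset the rowset size before the per-row SQLGetData loop.
    SQLSetStmtAttr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)1, 0);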

// Initialize column buffers
ColumnBuffers buffers(numCols, fetchSize);

@@ -2765,7 +2778,7 @@
SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0);
SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0);

ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched);
ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched, lobColumns);
if (!SQL_SUCCEEDED(ret) && ret != SQL_NO_DATA) {
LOG("Error when fetching data");
return ret;
@@ -2844,6 +2857,35 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows) {
}
LOG("Fetching data in batch sizes of {}", fetchSize);

std::vector<SQLUSMALLINT> lobColumns;
for (SQLSMALLINT i = 0; i < numCols; i++) {
auto colMeta = columnNames[i].cast<py::dict>();
SQLSMALLINT dataType = colMeta["DataType"].cast<SQLSMALLINT>();
SQLULEN columnSize = colMeta["ColumnSize"].cast<SQLULEN>();

if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR ||
dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR ||
dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY) &&
(columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) {
lobColumns.push_back(i + 1); // 1-based
}
}

// If we have LOBs → fall back to row-by-row fetch + SQLGetData_wrap
if (!lobColumns.empty()) {
LOG("LOB columns detected → using per-row SQLGetData path");
while (true) {
ret = SQLFetch_ptr(hStmt);
if (ret == SQL_NO_DATA) break;
if (!SQL_SUCCEEDED(ret)) return ret;

py::list row;
SQLGetData_wrap(StatementHandle, numCols, row); // <-- streams LOBs correctly
rows.append(row);
}
return SQL_SUCCESS;
}

ColumnBuffers buffers(numCols, fetchSize);

// Bind columns
@@ -2858,7 +2900,7 @@
SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0);

while (ret != SQL_NO_DATA) {
ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched);
ret = FetchBatchData(hStmt, buffers, columnNames, rows, numCols, numRowsFetched, lobColumns);
if (!SQL_SUCCEEDED(ret) && ret != SQL_NO_DATA) {
LOG("Error when fetching data");
return ret;