Provider.cpp
上传用户:zhuzhu0204
上传日期:2020-07-13
资源大小:13165k
文件大小:8k
开发平台:

Visual C++

  1. // Provider.cpp
  2. // 声明要使用UNICODE字符串
  3. #define UNICODE
  4. #define _UNICODE
  5. #include "stdafx.h"
  6. #include <Ws2spi.h>
  7. #include <Sporder.h>
  8. #include "Provider.h"
  9. #pragma comment(lib, "Ws2_32.lib")
  10. #pragma comment(lib, "sporder.lib")
  11. // Hard-coded GUID of the LSP being installed; the same value is needed again when the provider is removed
  12. GUID  ProviderGuid = {0xd3c21122, 0x85e1, 0x48f3, {0x9a,0xb6,0x23,0xd9,0x0c,0x73,0x07,0xef}};
  13. LPWSAPROTOCOL_INFOW GetProvider(LPINT lpnTotalProtocols)
  14. {
  15. DWORD dwSize = 0;
  16. int nError;
  17. LPWSAPROTOCOL_INFOW pProtoInfo = NULL;
  18. // 取得需要的长度
  19. if(::WSCEnumProtocols(NULL, pProtoInfo, &dwSize, &nError) == SOCKET_ERROR)
  20. {
  21. if(nError != WSAENOBUFS)
  22. return NULL;
  23. }
  24. pProtoInfo = (LPWSAPROTOCOL_INFOW)::GlobalAlloc(GPTR, dwSize);
  25. *lpnTotalProtocols = ::WSCEnumProtocols(NULL, pProtoInfo, &dwSize, &nError);
  26. return pProtoInfo;
  27. }
  28. void FreeProvider(LPWSAPROTOCOL_INFOW pProtoInfo)
  29. {
  30. ::GlobalFree(pProtoInfo);
  31. }
  32. BOOL InstallProvider(WCHAR *pwszPathName)
  33. {
  34. WCHAR wszLSPName[] = L"MyLsp";
  35. LPWSAPROTOCOL_INFOW pProtoInfo;
  36. int nProtocols;
  37. WSAPROTOCOL_INFOW OriginalProtocolInfo[3];
  38. DWORD  dwOrigCatalogId[3];
  39. int nArrayCount = 0;
  40. DWORD dwLayeredCatalogId; // 我们分层协议的目录ID号
  41. int nError;
  42. // 找到我们的下层协议,将信息放入数组中
  43. // 枚举所有服务程序提供者
  44. pProtoInfo = GetProvider(&nProtocols);
  45. BOOL bFindUdp = FALSE;
  46. BOOL bFindTcp = FALSE;
  47. BOOL bFindRaw = FALSE;
  48. for(int i=0; i<nProtocols; i++)
  49. {
  50. if(pProtoInfo[i].iAddressFamily == AF_INET)
  51. {
  52. if(!bFindUdp && pProtoInfo[i].iProtocol == IPPROTO_UDP)
  53. {
  54. memcpy(&OriginalProtocolInfo[nArrayCount], &pProtoInfo[i], sizeof(WSAPROTOCOL_INFOW));
  55. OriginalProtocolInfo[nArrayCount].dwServiceFlags1 = 
  56. OriginalProtocolInfo[nArrayCount].dwServiceFlags1 & (~XP1_IFS_HANDLES); 
  57. dwOrigCatalogId[nArrayCount++] = pProtoInfo[i].dwCatalogEntryId;
  58. bFindUdp = TRUE;
  59. }
  60. if(!bFindTcp && pProtoInfo[i].iProtocol == IPPROTO_TCP)
  61. {
  62. memcpy(&OriginalProtocolInfo[nArrayCount], &pProtoInfo[i], sizeof(WSAPROTOCOL_INFOW));
  63. OriginalProtocolInfo[nArrayCount].dwServiceFlags1 = 
  64. OriginalProtocolInfo[nArrayCount].dwServiceFlags1 & (~XP1_IFS_HANDLES); 
  65. dwOrigCatalogId[nArrayCount++] = pProtoInfo[i].dwCatalogEntryId;
  66. bFindTcp = TRUE;
  67. if(!bFindRaw && pProtoInfo[i].iProtocol == IPPROTO_IP)
  68. {
  69. memcpy(&OriginalProtocolInfo[nArrayCount], &pProtoInfo[i], sizeof(WSAPROTOCOL_INFOW));
  70. OriginalProtocolInfo[nArrayCount].dwServiceFlags1 = 
  71. OriginalProtocolInfo[nArrayCount].dwServiceFlags1 & (~XP1_IFS_HANDLES); 
  72. dwOrigCatalogId[nArrayCount++] = pProtoInfo[i].dwCatalogEntryId;
  73. bFindRaw = TRUE;
  74. }
  75. }
  76. }  
  77. // 安装我们的分层协议,获取一个dwLayeredCatalogId
  78. // 随便找一个下层协议的结构复制过来即可
  79. WSAPROTOCOL_INFOW LayeredProtocolInfo;
  80. memcpy(&LayeredProtocolInfo, &OriginalProtocolInfo[0], sizeof(WSAPROTOCOL_INFOW));
  81. // 修改协议名称,类型,设置PFL_HIDDEN标志
  82. wcscpy(LayeredProtocolInfo.szProtocol, wszLSPName);
  83. LayeredProtocolInfo.ProtocolChain.ChainLen = LAYERED_PROTOCOL; // 0;
  84. LayeredProtocolInfo.dwProviderFlags |= PFL_HIDDEN;
  85. // 安装
  86. if(::WSCInstallProvider(&ProviderGuid, 
  87. pwszPathName, &LayeredProtocolInfo, 1, &nError) == SOCKET_ERROR)
  88. {
  89. return FALSE;
  90. }
  91. // 重新枚举协议,获取分层协议的目录ID号
  92. FreeProvider(pProtoInfo);
  93. pProtoInfo = GetProvider(&nProtocols);
  94. for(i=0; i<nProtocols; i++)
  95. {
  96. if(memcmp(&pProtoInfo[i].ProviderId, &ProviderGuid, sizeof(ProviderGuid)) == 0)
  97. {
  98. dwLayeredCatalogId = pProtoInfo[i].dwCatalogEntryId;
  99. break;
  100. }
  101. }
  102. // 安装协议链
  103. // 修改协议名称,类型
  104. WCHAR wszChainName[WSAPROTOCOL_LEN + 1];
  105. for(i=0; i<nArrayCount; i++)
  106. {
  107. swprintf(wszChainName, L"%ws", wszLSPName);
  108. wcscpy(OriginalProtocolInfo[i].szProtocol, wszChainName);
  109. if(OriginalProtocolInfo[i].ProtocolChain.ChainLen == 1)
  110. {
  111. OriginalProtocolInfo[i].ProtocolChain.ChainEntries[1] = dwOrigCatalogId[i];
  112. }
  113. else
  114. {
  115. for(int j = OriginalProtocolInfo[i].ProtocolChain.ChainLen; j>0; j--)
  116. {
  117. OriginalProtocolInfo[i].ProtocolChain.ChainEntries[j] 
  118. = OriginalProtocolInfo[i].ProtocolChain.ChainEntries[j-1];
  119. }
  120. }
  121. OriginalProtocolInfo[i].ProtocolChain.ChainLen ++;
  122. OriginalProtocolInfo[i].ProtocolChain.ChainEntries[0] = dwLayeredCatalogId;
  123. }
  124. // 获取一个Guid,安装之
  125. GUID ProviderChainGuid;
  126. if(::UuidCreate(&ProviderChainGuid) == RPC_S_OK)
  127. {
  128. if(::WSCInstallProvider(&ProviderChainGuid, 
  129. pwszPathName, OriginalProtocolInfo, nArrayCount, &nError) == SOCKET_ERROR)
  130. {
  131. return FALSE;
  132. }
  133. }
  134. else
  135. return FALSE;
  136. // 将我们的协议提前,重新排序Winsock目录
  137. // 重新枚举安装的协议
  138. FreeProvider(pProtoInfo);
  139. pProtoInfo = GetProvider(&nProtocols);
  140. DWORD dwIds[20];
  141. int nIndex = 0;
  142. // 添加我们的协议链
  143. for(i=0; i<nProtocols; i++)
  144. {
  145. if((pProtoInfo[i].ProtocolChain.ChainLen > 1) &&
  146. (pProtoInfo[i].ProtocolChain.ChainEntries[0] == dwLayeredCatalogId))
  147. dwIds[nIndex++] = pProtoInfo[i].dwCatalogEntryId;
  148. }
  149. // 添加其它协议
  150. for(i=0; i<nProtocols; i++)
  151. {
  152. if((pProtoInfo[i].ProtocolChain.ChainLen <= 1) ||
  153. (pProtoInfo[i].ProtocolChain.ChainEntries[0] != dwLayeredCatalogId))
  154. dwIds[nIndex++] = pProtoInfo[i].dwCatalogEntryId;
  155. }
  156. // 重新排序Winsock目录
  157. if(nError = ::WSCWriteProviderOrder(dwIds, nIndex) != ERROR_SUCCESS)
  158. {
  159. return FALSE;
  160. }
  161. FreeProvider(pProtoInfo);
  162. return TRUE;
  163. }
  164. BOOL RemoveProvider()
  165. {
  166. LPWSAPROTOCOL_INFOW pProtoInfo;
  167. int nProtocols;
  168. DWORD dwLayeredCatalogId;
  169. // 根据Guid取得分层协议的目录ID号
  170. pProtoInfo = GetProvider(&nProtocols);
  171. int nError;
  172. for(int i=0; i<nProtocols; i++)
  173. {
  174. if(memcmp(&ProviderGuid, &pProtoInfo[i].ProviderId, sizeof(ProviderGuid)) == 0)
  175. {
  176. dwLayeredCatalogId = pProtoInfo[i].dwCatalogEntryId;
  177. break;
  178. }
  179. }
  180. if(i < nProtocols)
  181. {
  182. // 移除协议链
  183. for(i=0; i<nProtocols; i++)
  184. {
  185. if((pProtoInfo[i].ProtocolChain.ChainLen > 1) &&
  186. (pProtoInfo[i].ProtocolChain.ChainEntries[0] == dwLayeredCatalogId))
  187. {
  188. ::WSCDeinstallProvider(&pProtoInfo[i].ProviderId, &nError);
  189. }
  190. }
  191. // 移除分层协议
  192. ::WSCDeinstallProvider(&ProviderGuid, &nError);
  193. }
  194. return TRUE;
  195. }
  196. BOOL IsProviderInstalled()
  197. {
  198. WCHAR wszLSPName[] = L"MyLsp";
  199. LPWSAPROTOCOL_INFOW pProtoInfo;
  200. int nProtocols;
  201. pProtoInfo=GetProvider(&nProtocols);
  202. for(int i=0;i<nProtocols;i++)
  203. {
  204. if (wcscmp(wszLSPName, pProtoInfo[i].szProtocol) == 0)
  205. {
  206. return TRUE;
  207. }
  208. }
  209. return FALSE;
  210. }
  211. /*///////////////////////////////////////////////////////////////////
  212. void RemoveAllLayeredEntries()
  213. {
  214.  BOOL    bLayer;
  215.     int     ErrorCode,
  216.             i;
  217. int TotalProtocols;
  218. LPWSAPROTOCOL_INFOW ProtocolInfo;
  219.     while (1)
  220.     {
  221.         bLayer = FALSE;
  222.         ProtocolInfo = GetProvider(&TotalProtocols);
  223.         if (!ProtocolInfo)
  224.         {
  225.             printf("Unable to enumerate Winsock catalog!\n");
  226.             return;
  227.         }
  228.         for(i=0; i < TotalProtocols ;i++)
  229.         {
  230.             if (ProtocolInfo[i].ProtocolChain.ChainLen != BASE_PROTOCOL)
  231.             {
  232.                 bLayer = TRUE;
  233.                 printf("Removing '%S'\n", ProtocolInfo[i].szProtocol);
  234.                 if (WSCDeinstallProvider(&ProtocolInfo[i].ProviderId, &ErrorCode) == SOCKET_ERROR)
  235.                 {
  236.                     printf("Failed to remove [%S]: Error %d\n", ProtocolInfo[i].szProtocol, ErrorCode);
  237.                 }
  238.                 break;
  239.             }
  240.         }
  241.         FreeProvider(ProtocolInfo);
  242.         if (bLayer == FALSE)
  243.         {
  244.             break;
  245.         }
  246.     }
  247. }
  248. */