diff options
Diffstat (limited to 'src/core/hle/service/ssl')
| -rw-r--r-- | src/core/hle/service/ssl/ssl.cpp | 349 | ||||
| -rw-r--r-- | src/core/hle/service/ssl/ssl_backend.h | 44 | ||||
| -rw-r--r-- | src/core/hle/service/ssl/ssl_backend_none.cpp | 15 | ||||
| -rw-r--r-- | src/core/hle/service/ssl/ssl_backend_openssl.cpp | 342 | ||||
| -rw-r--r-- | src/core/hle/service/ssl/ssl_backend_schannel.cpp | 529 |
5 files changed, 1262 insertions, 17 deletions
diff --git a/src/core/hle/service/ssl/ssl.cpp b/src/core/hle/service/ssl/ssl.cpp index 2b99dd7ac..a3b54c7f0 100644 --- a/src/core/hle/service/ssl/ssl.cpp +++ b/src/core/hle/service/ssl/ssl.cpp | |||
| @@ -1,10 +1,18 @@ | |||
| 1 | // SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project | 1 | // SPDX-FileCopyrightText: Copyright 2018 yuzu Emulator Project |
| 2 | // SPDX-License-Identifier: GPL-2.0-or-later | 2 | // SPDX-License-Identifier: GPL-2.0-or-later |
| 3 | 3 | ||
| 4 | #include "common/string_util.h" | ||
| 5 | |||
| 6 | #include "core/core.h" | ||
| 4 | #include "core/hle/service/ipc_helpers.h" | 7 | #include "core/hle/service/ipc_helpers.h" |
| 5 | #include "core/hle/service/server_manager.h" | 8 | #include "core/hle/service/server_manager.h" |
| 6 | #include "core/hle/service/service.h" | 9 | #include "core/hle/service/service.h" |
| 10 | #include "core/hle/service/sm/sm.h" | ||
| 11 | #include "core/hle/service/sockets/bsd.h" | ||
| 7 | #include "core/hle/service/ssl/ssl.h" | 12 | #include "core/hle/service/ssl/ssl.h" |
| 13 | #include "core/hle/service/ssl/ssl_backend.h" | ||
| 14 | #include "core/internal_network/network.h" | ||
| 15 | #include "core/internal_network/sockets.h" | ||
| 8 | 16 | ||
| 9 | namespace Service::SSL { | 17 | namespace Service::SSL { |
| 10 | 18 | ||
| @@ -20,6 +28,18 @@ enum class ContextOption : u32 { | |||
| 20 | CrlImportDateCheckEnable = 1, | 28 | CrlImportDateCheckEnable = 1, |
| 21 | }; | 29 | }; |
| 22 | 30 | ||
| 31 | // This is nn::ssl::Connection::IoMode | ||
| 32 | enum class IoMode : u32 { | ||
| 33 | Blocking = 1, | ||
| 34 | NonBlocking = 2, | ||
| 35 | }; | ||
| 36 | |||
| 37 | // This is nn::ssl::sf::OptionType | ||
| 38 | enum class OptionType : u32 { | ||
| 39 | DoNotCloseSocket = 0, | ||
| 40 | GetServerCertChain = 1, | ||
| 41 | }; | ||
| 42 | |||
| 23 | // This is nn::ssl::sf::SslVersion | 43 | // This is nn::ssl::sf::SslVersion |
| 24 | struct SslVersion { | 44 | struct SslVersion { |
| 25 | union { | 45 | union { |
| @@ -34,35 +54,42 @@ struct SslVersion { | |||
| 34 | }; | 54 | }; |
| 35 | }; | 55 | }; |
| 36 | 56 | ||
| 57 | struct SslContextSharedData { | ||
| 58 | u32 connection_count = 0; | ||
| 59 | }; | ||
| 60 | |||
| 37 | class ISslConnection final : public ServiceFramework<ISslConnection> { | 61 | class ISslConnection final : public ServiceFramework<ISslConnection> { |
| 38 | public: | 62 | public: |
| 39 | explicit ISslConnection(Core::System& system_, SslVersion version) | 63 | explicit ISslConnection(Core::System& system_, SslVersion version, |
| 40 | : ServiceFramework{system_, "ISslConnection"}, ssl_version{version} { | 64 | std::shared_ptr<SslContextSharedData>& shared_data, |
| 65 | std::unique_ptr<SSLConnectionBackend>&& backend) | ||
| 66 | : ServiceFramework{system_, "ISslConnection"}, ssl_version{version}, | ||
| 67 | shared_data_{shared_data}, backend_{std::move(backend)} { | ||
| 41 | // clang-format off | 68 | // clang-format off |
| 42 | static const FunctionInfo functions[] = { | 69 | static const FunctionInfo functions[] = { |
| 43 | {0, nullptr, "SetSocketDescriptor"}, | 70 | {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"}, |
| 44 | {1, nullptr, "SetHostName"}, | 71 | {1, &ISslConnection::SetHostName, "SetHostName"}, |
| 45 | {2, nullptr, "SetVerifyOption"}, | 72 | {2, &ISslConnection::SetVerifyOption, "SetVerifyOption"}, |
| 46 | {3, nullptr, "SetIoMode"}, | 73 | {3, &ISslConnection::SetIoMode, "SetIoMode"}, |
| 47 | {4, nullptr, "GetSocketDescriptor"}, | 74 | {4, nullptr, "GetSocketDescriptor"}, |
| 48 | {5, nullptr, "GetHostName"}, | 75 | {5, nullptr, "GetHostName"}, |
| 49 | {6, nullptr, "GetVerifyOption"}, | 76 | {6, nullptr, "GetVerifyOption"}, |
| 50 | {7, nullptr, "GetIoMode"}, | 77 | {7, nullptr, "GetIoMode"}, |
| 51 | {8, nullptr, "DoHandshake"}, | 78 | {8, &ISslConnection::DoHandshake, "DoHandshake"}, |
| 52 | {9, nullptr, "DoHandshakeGetServerCert"}, | 79 | {9, &ISslConnection::DoHandshakeGetServerCert, "DoHandshakeGetServerCert"}, |
| 53 | {10, nullptr, "Read"}, | 80 | {10, &ISslConnection::Read, "Read"}, |
| 54 | {11, nullptr, "Write"}, | 81 | {11, &ISslConnection::Write, "Write"}, |
| 55 | {12, nullptr, "Pending"}, | 82 | {12, &ISslConnection::Pending, "Pending"}, |
| 56 | {13, nullptr, "Peek"}, | 83 | {13, nullptr, "Peek"}, |
| 57 | {14, nullptr, "Poll"}, | 84 | {14, nullptr, "Poll"}, |
| 58 | {15, nullptr, "GetVerifyCertError"}, | 85 | {15, nullptr, "GetVerifyCertError"}, |
| 59 | {16, nullptr, "GetNeededServerCertBufferSize"}, | 86 | {16, nullptr, "GetNeededServerCertBufferSize"}, |
| 60 | {17, nullptr, "SetSessionCacheMode"}, | 87 | {17, &ISslConnection::SetSessionCacheMode, "SetSessionCacheMode"}, |
| 61 | {18, nullptr, "GetSessionCacheMode"}, | 88 | {18, nullptr, "GetSessionCacheMode"}, |
| 62 | {19, nullptr, "FlushSessionCache"}, | 89 | {19, nullptr, "FlushSessionCache"}, |
| 63 | {20, nullptr, "SetRenegotiationMode"}, | 90 | {20, nullptr, "SetRenegotiationMode"}, |
| 64 | {21, nullptr, "GetRenegotiationMode"}, | 91 | {21, nullptr, "GetRenegotiationMode"}, |
| 65 | {22, nullptr, "SetOption"}, | 92 | {22, &ISslConnection::SetOption, "SetOption"}, |
| 66 | {23, nullptr, "GetOption"}, | 93 | {23, nullptr, "GetOption"}, |
| 67 | {24, nullptr, "GetVerifyCertErrors"}, | 94 | {24, nullptr, "GetVerifyCertErrors"}, |
| 68 | {25, nullptr, "GetCipherInfo"}, | 95 | {25, nullptr, "GetCipherInfo"}, |
| @@ -80,21 +107,295 @@ public: | |||
| 80 | // clang-format on | 107 | // clang-format on |
| 81 | 108 | ||
| 82 | RegisterHandlers(functions); | 109 | RegisterHandlers(functions); |
| 110 | |||
| 111 | shared_data->connection_count++; | ||
| 112 | } | ||
| 113 | |||
| 114 | ~ISslConnection() { | ||
| 115 | shared_data_->connection_count--; | ||
| 116 | if (fd_to_close_.has_value()) { | ||
| 117 | s32 fd = *fd_to_close_; | ||
| 118 | if (!do_not_close_socket_) { | ||
| 119 | LOG_ERROR(Service_SSL, | ||
| 120 | "do_not_close_socket was changed after setting socket; is this right?"); | ||
| 121 | } else { | ||
| 122 | auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u"); | ||
| 123 | if (bsd) { | ||
| 124 | auto err = bsd->CloseImpl(fd); | ||
| 125 | if (err != Service::Sockets::Errno::SUCCESS) { | ||
| 126 | LOG_ERROR(Service_SSL, "failed to close duplicated socket: {}", err); | ||
| 127 | } | ||
| 128 | } | ||
| 129 | } | ||
| 130 | } | ||
| 83 | } | 131 | } |
| 84 | 132 | ||
| 85 | private: | 133 | private: |
| 86 | SslVersion ssl_version; | 134 | SslVersion ssl_version; |
| 135 | std::shared_ptr<SslContextSharedData> shared_data_; | ||
| 136 | std::unique_ptr<SSLConnectionBackend> backend_; | ||
| 137 | std::optional<int> fd_to_close_; | ||
| 138 | bool do_not_close_socket_ = false; | ||
| 139 | bool get_server_cert_chain_ = false; | ||
| 140 | std::shared_ptr<Network::SocketBase> socket_; | ||
| 141 | bool did_set_host_name_ = false; | ||
| 142 | bool did_handshake_ = false; | ||
| 143 | |||
| 144 | ResultVal<s32> SetSocketDescriptorImpl(s32 fd) { | ||
| 145 | LOG_DEBUG(Service_SSL, "called, fd={}", fd); | ||
| 146 | ASSERT(!did_handshake_); | ||
| 147 | auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u"); | ||
| 148 | ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; }); | ||
| 149 | s32 ret_fd; | ||
| 150 | // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor | ||
| 151 | if (do_not_close_socket_) { | ||
| 152 | auto res = bsd->DuplicateSocketImpl(fd); | ||
| 153 | if (!res.has_value()) { | ||
| 154 | LOG_ERROR(Service_SSL, "failed to duplicate socket"); | ||
| 155 | return ResultInvalidSocket; | ||
| 156 | } | ||
| 157 | fd = *res; | ||
| 158 | fd_to_close_ = fd; | ||
| 159 | ret_fd = fd; | ||
| 160 | } else { | ||
| 161 | ret_fd = -1; | ||
| 162 | } | ||
| 163 | std::optional<std::shared_ptr<Network::SocketBase>> sock = bsd->GetSocket(fd); | ||
| 164 | if (!sock.has_value()) { | ||
| 165 | LOG_ERROR(Service_SSL, "invalid socket fd {}", fd); | ||
| 166 | return ResultInvalidSocket; | ||
| 167 | } | ||
| 168 | socket_ = std::move(*sock); | ||
| 169 | backend_->SetSocket(socket_); | ||
| 170 | return ret_fd; | ||
| 171 | } | ||
| 172 | |||
| 173 | Result SetHostNameImpl(const std::string& hostname) { | ||
| 174 | LOG_DEBUG(Service_SSL, "SetHostNameImpl({})", hostname); | ||
| 175 | ASSERT(!did_handshake_); | ||
| 176 | Result res = backend_->SetHostName(hostname); | ||
| 177 | if (res == ResultSuccess) { | ||
| 178 | did_set_host_name_ = true; | ||
| 179 | } | ||
| 180 | return res; | ||
| 181 | } | ||
| 182 | |||
| 183 | Result SetVerifyOptionImpl(u32 option) { | ||
| 184 | ASSERT(!did_handshake_); | ||
| 185 | LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option); | ||
| 186 | return ResultSuccess; | ||
| 187 | } | ||
| 188 | |||
| 189 | Result SetIOModeImpl(u32 _mode) { | ||
| 190 | auto mode = static_cast<IoMode>(_mode); | ||
| 191 | ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking); | ||
| 192 | ASSERT_OR_EXECUTE(socket_, { return ResultNoSocket; }); | ||
| 193 | |||
| 194 | bool non_block = mode == IoMode::NonBlocking; | ||
| 195 | Network::Errno e = socket_->SetNonBlock(non_block); | ||
| 196 | if (e != Network::Errno::SUCCESS) { | ||
| 197 | LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block); | ||
| 198 | } | ||
| 199 | return ResultSuccess; | ||
| 200 | } | ||
| 201 | |||
| 202 | Result SetSessionCacheModeImpl(u32 mode) { | ||
| 203 | ASSERT(!did_handshake_); | ||
| 204 | LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode); | ||
| 205 | return ResultSuccess; | ||
| 206 | } | ||
| 207 | |||
| 208 | Result DoHandshakeImpl() { | ||
| 209 | ASSERT_OR_EXECUTE(!did_handshake_ && socket_, { return ResultNoSocket; }); | ||
| 210 | ASSERT_OR_EXECUTE_MSG( | ||
| 211 | did_set_host_name_, { return ResultInternalError; }, | ||
| 212 | "Expected SetHostName before DoHandshake"); | ||
| 213 | Result res = backend_->DoHandshake(); | ||
| 214 | did_handshake_ = res.IsSuccess(); | ||
| 215 | return res; | ||
| 216 | } | ||
| 217 | |||
| 218 | std::vector<u8> SerializeServerCerts(const std::vector<std::vector<u8>>& certs) { | ||
| 219 | struct Header { | ||
| 220 | u64 magic; | ||
| 221 | u32 count; | ||
| 222 | u32 pad; | ||
| 223 | }; | ||
| 224 | struct EntryHeader { | ||
| 225 | u32 size; | ||
| 226 | u32 offset; | ||
| 227 | }; | ||
| 228 | if (!get_server_cert_chain_) { | ||
| 229 | // Just return the first one, unencoded. | ||
| 230 | ASSERT_OR_EXECUTE_MSG( | ||
| 231 | !certs.empty(), { return {}; }, "Should be at least one server cert"); | ||
| 232 | return certs[0]; | ||
| 233 | } | ||
| 234 | std::vector<u8> ret; | ||
| 235 | Header header{0x4E4D684374726543, static_cast<u32>(certs.size()), 0}; | ||
| 236 | ret.insert(ret.end(), reinterpret_cast<u8*>(&header), reinterpret_cast<u8*>(&header + 1)); | ||
| 237 | size_t data_offset = sizeof(Header) + certs.size() * sizeof(EntryHeader); | ||
| 238 | for (auto& cert : certs) { | ||
| 239 | EntryHeader entry_header{static_cast<u32>(cert.size()), static_cast<u32>(data_offset)}; | ||
| 240 | data_offset += cert.size(); | ||
| 241 | ret.insert(ret.end(), reinterpret_cast<u8*>(&entry_header), | ||
| 242 | reinterpret_cast<u8*>(&entry_header + 1)); | ||
| 243 | } | ||
| 244 | for (auto& cert : certs) { | ||
| 245 | ret.insert(ret.end(), cert.begin(), cert.end()); | ||
| 246 | } | ||
| 247 | return ret; | ||
| 248 | } | ||
| 249 | |||
| 250 | ResultVal<std::vector<u8>> ReadImpl(size_t size) { | ||
| 251 | ASSERT_OR_EXECUTE(did_handshake_, { return ResultInternalError; }); | ||
| 252 | std::vector<u8> res(size); | ||
| 253 | ResultVal<size_t> actual = backend_->Read(res); | ||
| 254 | if (actual.Failed()) { | ||
| 255 | return actual.Code(); | ||
| 256 | } | ||
| 257 | res.resize(*actual); | ||
| 258 | return res; | ||
| 259 | } | ||
| 260 | |||
| 261 | ResultVal<size_t> WriteImpl(std::span<const u8> data) { | ||
| 262 | ASSERT_OR_EXECUTE(did_handshake_, { return ResultInternalError; }); | ||
| 263 | return backend_->Write(data); | ||
| 264 | } | ||
| 265 | |||
| 266 | ResultVal<s32> PendingImpl() { | ||
| 267 | LOG_WARNING(Service_SSL, "(STUBBED) called."); | ||
| 268 | return 0; | ||
| 269 | } | ||
| 270 | |||
| 271 | void SetSocketDescriptor(HLERequestContext& ctx) { | ||
| 272 | IPC::RequestParser rp{ctx}; | ||
| 273 | const s32 fd = rp.Pop<s32>(); | ||
| 274 | const ResultVal<s32> res = SetSocketDescriptorImpl(fd); | ||
| 275 | IPC::ResponseBuilder rb{ctx, 3}; | ||
| 276 | rb.Push(res.Code()); | ||
| 277 | rb.Push<s32>(res.ValueOr(-1)); | ||
| 278 | } | ||
| 279 | |||
| 280 | void SetHostName(HLERequestContext& ctx) { | ||
| 281 | const std::string hostname = Common::StringFromBuffer(ctx.ReadBuffer()); | ||
| 282 | const Result res = SetHostNameImpl(hostname); | ||
| 283 | IPC::ResponseBuilder rb{ctx, 2}; | ||
| 284 | rb.Push(res); | ||
| 285 | } | ||
| 286 | |||
| 287 | void SetVerifyOption(HLERequestContext& ctx) { | ||
| 288 | IPC::RequestParser rp{ctx}; | ||
| 289 | const u32 option = rp.Pop<u32>(); | ||
| 290 | const Result res = SetVerifyOptionImpl(option); | ||
| 291 | IPC::ResponseBuilder rb{ctx, 2}; | ||
| 292 | rb.Push(res); | ||
| 293 | } | ||
| 294 | |||
| 295 | void SetIoMode(HLERequestContext& ctx) { | ||
| 296 | IPC::RequestParser rp{ctx}; | ||
| 297 | const u32 mode = rp.Pop<u32>(); | ||
| 298 | const Result res = SetIOModeImpl(mode); | ||
| 299 | IPC::ResponseBuilder rb{ctx, 2}; | ||
| 300 | rb.Push(res); | ||
| 301 | } | ||
| 302 | |||
| 303 | void DoHandshake(HLERequestContext& ctx) { | ||
| 304 | const Result res = DoHandshakeImpl(); | ||
| 305 | IPC::ResponseBuilder rb{ctx, 2}; | ||
| 306 | rb.Push(res); | ||
| 307 | } | ||
| 308 | |||
| 309 | void DoHandshakeGetServerCert(HLERequestContext& ctx) { | ||
| 310 | Result res = DoHandshakeImpl(); | ||
| 311 | u32 certs_count = 0; | ||
| 312 | u32 certs_size = 0; | ||
| 313 | if (res == ResultSuccess) { | ||
| 314 | auto certs = backend_->GetServerCerts(); | ||
| 315 | if (certs.Succeeded()) { | ||
| 316 | std::vector<u8> certs_buf = SerializeServerCerts(*certs); | ||
| 317 | ctx.WriteBuffer(certs_buf); | ||
| 318 | certs_count = static_cast<u32>(certs->size()); | ||
| 319 | certs_size = static_cast<u32>(certs_buf.size()); | ||
| 320 | } | ||
| 321 | } | ||
| 322 | IPC::ResponseBuilder rb{ctx, 4}; | ||
| 323 | rb.Push(res); | ||
| 324 | rb.Push(certs_size); | ||
| 325 | rb.Push(certs_count); | ||
| 326 | } | ||
| 327 | |||
| 328 | void Read(HLERequestContext& ctx) { | ||
| 329 | const ResultVal<std::vector<u8>> res = ReadImpl(ctx.GetWriteBufferSize()); | ||
| 330 | IPC::ResponseBuilder rb{ctx, 3}; | ||
| 331 | rb.Push(res.Code()); | ||
| 332 | if (res.Succeeded()) { | ||
| 333 | rb.Push(static_cast<u32>(res->size())); | ||
| 334 | ctx.WriteBuffer(*res); | ||
| 335 | } else { | ||
| 336 | rb.Push(static_cast<u32>(0)); | ||
| 337 | } | ||
| 338 | } | ||
| 339 | |||
| 340 | void Write(HLERequestContext& ctx) { | ||
| 341 | const ResultVal<size_t> res = WriteImpl(ctx.ReadBuffer()); | ||
| 342 | IPC::ResponseBuilder rb{ctx, 3}; | ||
| 343 | rb.Push(res.Code()); | ||
| 344 | rb.Push(static_cast<u32>(res.ValueOr(0))); | ||
| 345 | } | ||
| 346 | |||
| 347 | void Pending(HLERequestContext& ctx) { | ||
| 348 | const ResultVal<s32> res = PendingImpl(); | ||
| 349 | IPC::ResponseBuilder rb{ctx, 3}; | ||
| 350 | rb.Push(res.Code()); | ||
| 351 | rb.Push<s32>(res.ValueOr(0)); | ||
| 352 | } | ||
| 353 | |||
| 354 | void SetSessionCacheMode(HLERequestContext& ctx) { | ||
| 355 | IPC::RequestParser rp{ctx}; | ||
| 356 | const u32 mode = rp.Pop<u32>(); | ||
| 357 | const Result res = SetSessionCacheModeImpl(mode); | ||
| 358 | IPC::ResponseBuilder rb{ctx, 2}; | ||
| 359 | rb.Push(res); | ||
| 360 | } | ||
| 361 | |||
| 362 | void SetOption(HLERequestContext& ctx) { | ||
| 363 | struct Parameters { | ||
| 364 | OptionType option; | ||
| 365 | s32 value; | ||
| 366 | }; | ||
| 367 | static_assert(sizeof(Parameters) == 0x8, "Parameters is an invalid size"); | ||
| 368 | |||
| 369 | IPC::RequestParser rp{ctx}; | ||
| 370 | const auto parameters = rp.PopRaw<Parameters>(); | ||
| 371 | |||
| 372 | switch (parameters.option) { | ||
| 373 | case OptionType::DoNotCloseSocket: | ||
| 374 | do_not_close_socket_ = static_cast<bool>(parameters.value); | ||
| 375 | break; | ||
| 376 | case OptionType::GetServerCertChain: | ||
| 377 | get_server_cert_chain_ = static_cast<bool>(parameters.value); | ||
| 378 | break; | ||
| 379 | default: | ||
| 380 | LOG_WARNING(Service_SSL, "unrecognized option={}, value={}", parameters.option, | ||
| 381 | parameters.value); | ||
| 382 | } | ||
| 383 | |||
| 384 | IPC::ResponseBuilder rb{ctx, 2}; | ||
| 385 | rb.Push(ResultSuccess); | ||
| 386 | } | ||
| 87 | }; | 387 | }; |
| 88 | 388 | ||
| 89 | class ISslContext final : public ServiceFramework<ISslContext> { | 389 | class ISslContext final : public ServiceFramework<ISslContext> { |
| 90 | public: | 390 | public: |
| 91 | explicit ISslContext(Core::System& system_, SslVersion version) | 391 | explicit ISslContext(Core::System& system_, SslVersion version) |
| 92 | : ServiceFramework{system_, "ISslContext"}, ssl_version{version} { | 392 | : ServiceFramework{system_, "ISslContext"}, ssl_version{version}, |
| 393 | shared_data_{std::make_shared<SslContextSharedData>()} { | ||
| 93 | static const FunctionInfo functions[] = { | 394 | static const FunctionInfo functions[] = { |
| 94 | {0, &ISslContext::SetOption, "SetOption"}, | 395 | {0, &ISslContext::SetOption, "SetOption"}, |
| 95 | {1, nullptr, "GetOption"}, | 396 | {1, nullptr, "GetOption"}, |
| 96 | {2, &ISslContext::CreateConnection, "CreateConnection"}, | 397 | {2, &ISslContext::CreateConnection, "CreateConnection"}, |
| 97 | {3, nullptr, "GetConnectionCount"}, | 398 | {3, &ISslContext::GetConnectionCount, "GetConnectionCount"}, |
| 98 | {4, &ISslContext::ImportServerPki, "ImportServerPki"}, | 399 | {4, &ISslContext::ImportServerPki, "ImportServerPki"}, |
| 99 | {5, &ISslContext::ImportClientPki, "ImportClientPki"}, | 400 | {5, &ISslContext::ImportClientPki, "ImportClientPki"}, |
| 100 | {6, nullptr, "RemoveServerPki"}, | 401 | {6, nullptr, "RemoveServerPki"}, |
| @@ -111,6 +412,7 @@ public: | |||
| 111 | 412 | ||
| 112 | private: | 413 | private: |
| 113 | SslVersion ssl_version; | 414 | SslVersion ssl_version; |
| 415 | std::shared_ptr<SslContextSharedData> shared_data_; | ||
| 114 | 416 | ||
| 115 | void SetOption(HLERequestContext& ctx) { | 417 | void SetOption(HLERequestContext& ctx) { |
| 116 | struct Parameters { | 418 | struct Parameters { |
| @@ -130,11 +432,24 @@ private: | |||
| 130 | } | 432 | } |
| 131 | 433 | ||
| 132 | void CreateConnection(HLERequestContext& ctx) { | 434 | void CreateConnection(HLERequestContext& ctx) { |
| 133 | LOG_WARNING(Service_SSL, "(STUBBED) called"); | 435 | LOG_WARNING(Service_SSL, "called"); |
| 436 | |||
| 437 | auto backend_res = CreateSSLConnectionBackend(); | ||
| 134 | 438 | ||
| 135 | IPC::ResponseBuilder rb{ctx, 2, 0, 1}; | 439 | IPC::ResponseBuilder rb{ctx, 2, 0, 1}; |
| 440 | rb.Push(backend_res.Code()); | ||
| 441 | if (backend_res.Succeeded()) { | ||
| 442 | rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data_, | ||
| 443 | std::move(*backend_res)); | ||
| 444 | } | ||
| 445 | } | ||
| 446 | |||
| 447 | void GetConnectionCount(HLERequestContext& ctx) { | ||
| 448 | LOG_WARNING(Service_SSL, "connection_count={}", shared_data_->connection_count); | ||
| 449 | |||
| 450 | IPC::ResponseBuilder rb{ctx, 3}; | ||
| 136 | rb.Push(ResultSuccess); | 451 | rb.Push(ResultSuccess); |
| 137 | rb.PushIpcInterface<ISslConnection>(system, ssl_version); | 452 | rb.Push(shared_data_->connection_count); |
| 138 | } | 453 | } |
| 139 | 454 | ||
| 140 | void ImportServerPki(HLERequestContext& ctx) { | 455 | void ImportServerPki(HLERequestContext& ctx) { |
diff --git a/src/core/hle/service/ssl/ssl_backend.h b/src/core/hle/service/ssl/ssl_backend.h new file mode 100644 index 000000000..624e07d41 --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend.h | |||
| @@ -0,0 +1,44 @@ | |||
| 1 | // SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project | ||
| 2 | // SPDX-License-Identifier: GPL-2.0-or-later | ||
| 3 | |||
| 4 | #pragma once | ||
| 5 | |||
| 6 | #include "core/hle/result.h" | ||
| 7 | |||
| 8 | #include "common/common_types.h" | ||
| 9 | |||
| 10 | #include <memory> | ||
| 11 | #include <span> | ||
| 12 | #include <string> | ||
| 13 | #include <vector> | ||
| 14 | |||
| 15 | namespace Network { | ||
| 16 | class SocketBase; | ||
| 17 | } | ||
| 18 | |||
| 19 | namespace Service::SSL { | ||
| 20 | |||
| 21 | constexpr Result ResultNoSocket{ErrorModule::SSLSrv, 103}; | ||
| 22 | constexpr Result ResultInvalidSocket{ErrorModule::SSLSrv, 106}; | ||
| 23 | constexpr Result ResultTimeout{ErrorModule::SSLSrv, 205}; | ||
| 24 | constexpr Result ResultInternalError{ErrorModule::SSLSrv, 999}; // made up | ||
| 25 | |||
| 26 | constexpr Result ResultWouldBlock{ErrorModule::SSLSrv, 204}; | ||
| 27 | // ^ ResultWouldBlock is returned from Read and Write, and oddly, DoHandshake, | ||
| 28 | // with no way in the latter case to distinguish whether the client should poll | ||
| 29 | // for read or write. The one official client I've seen handles this by always | ||
| 30 | // polling for read (with a timeout). | ||
| 31 | |||
| 32 | class SSLConnectionBackend { | ||
| 33 | public: | ||
| 34 | virtual void SetSocket(std::shared_ptr<Network::SocketBase> socket) = 0; | ||
| 35 | virtual Result SetHostName(const std::string& hostname) = 0; | ||
| 36 | virtual Result DoHandshake() = 0; | ||
| 37 | virtual ResultVal<size_t> Read(std::span<u8> data) = 0; | ||
| 38 | virtual ResultVal<size_t> Write(std::span<const u8> data) = 0; | ||
| 39 | virtual ResultVal<std::vector<std::vector<u8>>> GetServerCerts() = 0; | ||
| 40 | }; | ||
| 41 | |||
| 42 | ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend(); | ||
| 43 | |||
| 44 | } // namespace Service::SSL | ||
diff --git a/src/core/hle/service/ssl/ssl_backend_none.cpp b/src/core/hle/service/ssl/ssl_backend_none.cpp new file mode 100644 index 000000000..eb01561e2 --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_none.cpp | |||
| @@ -0,0 +1,15 @@ | |||
| 1 | // SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project | ||
| 2 | // SPDX-License-Identifier: GPL-2.0-or-later | ||
| 3 | |||
| 4 | #include "core/hle/service/ssl/ssl_backend.h" | ||
| 5 | |||
| 6 | #include "common/logging/log.h" | ||
| 7 | |||
| 8 | namespace Service::SSL { | ||
| 9 | |||
| 10 | ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { | ||
| 11 | LOG_ERROR(Service_SSL, "No SSL backend on this platform"); | ||
| 12 | return ResultInternalError; | ||
| 13 | } | ||
| 14 | |||
| 15 | } // namespace Service::SSL | ||
diff --git a/src/core/hle/service/ssl/ssl_backend_openssl.cpp b/src/core/hle/service/ssl/ssl_backend_openssl.cpp new file mode 100644 index 000000000..cf9b904ac --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_openssl.cpp | |||
| @@ -0,0 +1,342 @@ | |||
| 1 | // SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project | ||
| 2 | // SPDX-License-Identifier: GPL-2.0-or-later | ||
| 3 | |||
| 4 | #include "core/hle/service/ssl/ssl_backend.h" | ||
| 5 | #include "core/internal_network/network.h" | ||
| 6 | #include "core/internal_network/sockets.h" | ||
| 7 | |||
| 8 | #include "common/fs/file.h" | ||
| 9 | #include "common/hex_util.h" | ||
| 10 | #include "common/string_util.h" | ||
| 11 | |||
| 12 | #include <mutex> | ||
| 13 | |||
| 14 | #include <openssl/bio.h> | ||
| 15 | #include <openssl/err.h> | ||
| 16 | #include <openssl/ssl.h> | ||
| 17 | #include <openssl/x509.h> | ||
| 18 | |||
| 19 | using namespace Common::FS; | ||
| 20 | |||
| 21 | namespace Service::SSL { | ||
| 22 | |||
| 23 | // Import OpenSSL's `SSL` type into the namespace. This is needed because the | ||
| 24 | // namespace is also named `SSL`. | ||
| 25 | using ::SSL; | ||
| 26 | |||
| 27 | namespace { | ||
| 28 | |||
| 29 | std::once_flag one_time_init_flag; | ||
| 30 | bool one_time_init_success = false; | ||
| 31 | |||
| 32 | SSL_CTX* ssl_ctx; | ||
| 33 | IOFile key_log_file; // only open if SSLKEYLOGFILE set in environment | ||
| 34 | BIO_METHOD* bio_meth; | ||
| 35 | |||
| 36 | Result CheckOpenSSLErrors(); | ||
| 37 | void OneTimeInit(); | ||
| 38 | void OneTimeInitLogFile(); | ||
| 39 | bool OneTimeInitBIO(); | ||
| 40 | |||
| 41 | } // namespace | ||
| 42 | |||
| 43 | class SSLConnectionBackendOpenSSL final : public SSLConnectionBackend { | ||
| 44 | public: | ||
| 45 | Result Init() { | ||
| 46 | std::call_once(one_time_init_flag, OneTimeInit); | ||
| 47 | |||
| 48 | if (!one_time_init_success) { | ||
| 49 | LOG_ERROR(Service_SSL, | ||
| 50 | "Can't create SSL connection because OpenSSL one-time initialization failed"); | ||
| 51 | return ResultInternalError; | ||
| 52 | } | ||
| 53 | |||
| 54 | ssl_ = SSL_new(ssl_ctx); | ||
| 55 | if (!ssl_) { | ||
| 56 | LOG_ERROR(Service_SSL, "SSL_new failed"); | ||
| 57 | return CheckOpenSSLErrors(); | ||
| 58 | } | ||
| 59 | |||
| 60 | SSL_set_connect_state(ssl_); | ||
| 61 | |||
| 62 | bio_ = BIO_new(bio_meth); | ||
| 63 | if (!bio_) { | ||
| 64 | LOG_ERROR(Service_SSL, "BIO_new failed"); | ||
| 65 | return CheckOpenSSLErrors(); | ||
| 66 | } | ||
| 67 | |||
| 68 | BIO_set_data(bio_, this); | ||
| 69 | BIO_set_init(bio_, 1); | ||
| 70 | SSL_set_bio(ssl_, bio_, bio_); | ||
| 71 | |||
| 72 | return ResultSuccess; | ||
| 73 | } | ||
| 74 | |||
| 75 | void SetSocket(std::shared_ptr<Network::SocketBase> socket) override { | ||
| 76 | socket_ = socket; | ||
| 77 | } | ||
| 78 | |||
| 79 | Result SetHostName(const std::string& hostname) override { | ||
| 80 | if (!SSL_set1_host(ssl_, hostname.c_str())) { // hostname for verification | ||
| 81 | LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname); | ||
| 82 | return CheckOpenSSLErrors(); | ||
| 83 | } | ||
| 84 | if (!SSL_set_tlsext_host_name(ssl_, hostname.c_str())) { // hostname for SNI | ||
| 85 | LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname); | ||
| 86 | return CheckOpenSSLErrors(); | ||
| 87 | } | ||
| 88 | return ResultSuccess; | ||
| 89 | } | ||
| 90 | |||
| 91 | Result DoHandshake() override { | ||
| 92 | SSL_set_verify_result(ssl_, X509_V_OK); | ||
| 93 | int ret = SSL_do_handshake(ssl_); | ||
| 94 | long verify_result = SSL_get_verify_result(ssl_); | ||
| 95 | if (verify_result != X509_V_OK) { | ||
| 96 | LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}", | ||
| 97 | X509_verify_cert_error_string(verify_result)); | ||
| 98 | return CheckOpenSSLErrors(); | ||
| 99 | } | ||
| 100 | if (ret <= 0) { | ||
| 101 | int ssl_err = SSL_get_error(ssl_, ret); | ||
| 102 | if (ssl_err == SSL_ERROR_ZERO_RETURN || | ||
| 103 | (ssl_err == SSL_ERROR_SYSCALL && got_read_eof_)) { | ||
| 104 | LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); | ||
| 105 | return ResultInternalError; | ||
| 106 | } | ||
| 107 | } | ||
| 108 | return HandleReturn("SSL_do_handshake", 0, ret).Code(); | ||
| 109 | } | ||
| 110 | |||
| 111 | ResultVal<size_t> Read(std::span<u8> data) override { | ||
| 112 | size_t actual; | ||
| 113 | int ret = SSL_read_ex(ssl_, data.data(), data.size(), &actual); | ||
| 114 | return HandleReturn("SSL_read_ex", actual, ret); | ||
| 115 | } | ||
| 116 | |||
| 117 | ResultVal<size_t> Write(std::span<const u8> data) override { | ||
| 118 | size_t actual; | ||
| 119 | int ret = SSL_write_ex(ssl_, data.data(), data.size(), &actual); | ||
| 120 | return HandleReturn("SSL_write_ex", actual, ret); | ||
| 121 | } | ||
| 122 | |||
| 123 | ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) { | ||
| 124 | int ssl_err = SSL_get_error(ssl_, ret); | ||
| 125 | CheckOpenSSLErrors(); | ||
| 126 | switch (ssl_err) { | ||
| 127 | case SSL_ERROR_NONE: | ||
| 128 | return actual; | ||
| 129 | case SSL_ERROR_ZERO_RETURN: | ||
| 130 | LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_ZERO_RETURN", what); | ||
| 131 | // DoHandshake special-cases this, but for Read and Write: | ||
| 132 | return size_t(0); | ||
| 133 | case SSL_ERROR_WANT_READ: | ||
| 134 | LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_READ", what); | ||
| 135 | return ResultWouldBlock; | ||
| 136 | case SSL_ERROR_WANT_WRITE: | ||
| 137 | LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what); | ||
| 138 | return ResultWouldBlock; | ||
| 139 | default: | ||
| 140 | if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof_) { | ||
| 141 | LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what); | ||
| 142 | return size_t(0); | ||
| 143 | } | ||
| 144 | LOG_ERROR(Service_SSL, "{} => other SSL_get_error return value {}", what, ssl_err); | ||
| 145 | return ResultInternalError; | ||
| 146 | } | ||
| 147 | } | ||
| 148 | |||
| 149 | ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { | ||
| 150 | STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl_); | ||
| 151 | if (!chain) { | ||
| 152 | LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr"); | ||
| 153 | return ResultInternalError; | ||
| 154 | } | ||
| 155 | std::vector<std::vector<u8>> ret; | ||
| 156 | int count = sk_X509_num(chain); | ||
| 157 | ASSERT(count >= 0); | ||
| 158 | for (int i = 0; i < count; i++) { | ||
| 159 | X509* x509 = sk_X509_value(chain, i); | ||
| 160 | ASSERT_OR_EXECUTE(x509 != nullptr, { continue; }); | ||
| 161 | unsigned char* buf = nullptr; | ||
| 162 | int len = i2d_X509(x509, &buf); | ||
| 163 | ASSERT_OR_EXECUTE(len >= 0 && buf, { continue; }); | ||
| 164 | ret.emplace_back(buf, buf + len); | ||
| 165 | OPENSSL_free(buf); | ||
| 166 | } | ||
| 167 | return ret; | ||
| 168 | } | ||
| 169 | |||
| 170 | ~SSLConnectionBackendOpenSSL() { | ||
| 171 | // these are null-tolerant: | ||
| 172 | SSL_free(ssl_); | ||
| 173 | BIO_free(bio_); | ||
| 174 | } | ||
| 175 | |||
| 176 | static void KeyLogCallback(const SSL* ssl, const char* line) { | ||
| 177 | std::string str(line); | ||
| 178 | str.push_back('\n'); | ||
| 179 | // Do this in a single WriteString for atomicity if multiple instances | ||
| 180 | // are running on different threads (though that can't currently | ||
| 181 | // happen). | ||
| 182 | if (key_log_file.WriteString(str) != str.size() || !key_log_file.Flush()) { | ||
| 183 | LOG_CRITICAL(Service_SSL, "Failed to write to SSLKEYLOGFILE"); | ||
| 184 | } | ||
| 185 | LOG_DEBUG(Service_SSL, "Wrote to SSLKEYLOGFILE: {}", line); | ||
| 186 | } | ||
| 187 | |||
| 188 | static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) { | ||
| 189 | auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio)); | ||
| 190 | ASSERT_OR_EXECUTE_MSG( | ||
| 191 | self->socket_, { return 0; }, "OpenSSL asked to send but we have no socket"); | ||
| 192 | BIO_clear_retry_flags(bio); | ||
| 193 | auto [actual, err] = self->socket_->Send({reinterpret_cast<const u8*>(buf), len}, 0); | ||
| 194 | switch (err) { | ||
| 195 | case Network::Errno::SUCCESS: | ||
| 196 | *actual_p = actual; | ||
| 197 | return 1; | ||
| 198 | case Network::Errno::AGAIN: | ||
| 199 | BIO_set_flags(bio, BIO_FLAGS_WRITE | BIO_FLAGS_SHOULD_RETRY); | ||
| 200 | return 0; | ||
| 201 | default: | ||
| 202 | LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err); | ||
| 203 | return -1; | ||
| 204 | } | ||
| 205 | } | ||
| 206 | |||
| 207 | static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) { | ||
| 208 | auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio)); | ||
| 209 | ASSERT_OR_EXECUTE_MSG( | ||
| 210 | self->socket_, { return 0; }, "OpenSSL asked to recv but we have no socket"); | ||
| 211 | BIO_clear_retry_flags(bio); | ||
| 212 | auto [actual, err] = self->socket_->Recv(0, {reinterpret_cast<u8*>(buf), len}); | ||
| 213 | switch (err) { | ||
| 214 | case Network::Errno::SUCCESS: | ||
| 215 | *actual_p = actual; | ||
| 216 | if (actual == 0) { | ||
| 217 | self->got_read_eof_ = true; | ||
| 218 | } | ||
| 219 | return actual ? 1 : 0; | ||
| 220 | case Network::Errno::AGAIN: | ||
| 221 | BIO_set_flags(bio, BIO_FLAGS_READ | BIO_FLAGS_SHOULD_RETRY); | ||
| 222 | return 0; | ||
| 223 | default: | ||
| 224 | LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); | ||
| 225 | return -1; | ||
| 226 | } | ||
| 227 | } | ||
| 228 | |||
| 229 | static long CtrlCallback(BIO* bio, int cmd, long larg, void* parg) { | ||
| 230 | switch (cmd) { | ||
| 231 | case BIO_CTRL_FLUSH: | ||
| 232 | // Nothing to flush. | ||
| 233 | return 1; | ||
| 234 | case BIO_CTRL_PUSH: | ||
| 235 | case BIO_CTRL_POP: | ||
| 236 | case BIO_CTRL_GET_KTLS_SEND: | ||
| 237 | case BIO_CTRL_GET_KTLS_RECV: | ||
| 238 | // We don't support these operations, but don't bother logging them | ||
| 239 | // as they're nothing unusual. | ||
| 240 | return 0; | ||
| 241 | default: | ||
| 242 | LOG_DEBUG(Service_SSL, "OpenSSL BIO got ctrl({}, {}, {})", cmd, larg, parg); | ||
| 243 | return 0; | ||
| 244 | } | ||
| 245 | } | ||
| 246 | |||
| 247 | SSL* ssl_ = nullptr; | ||
| 248 | BIO* bio_ = nullptr; | ||
| 249 | bool got_read_eof_ = false; | ||
| 250 | |||
| 251 | std::shared_ptr<Network::SocketBase> socket_; | ||
| 252 | }; | ||
| 253 | |||
| 254 | ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { | ||
| 255 | auto conn = std::make_unique<SSLConnectionBackendOpenSSL>(); | ||
| 256 | Result res = conn->Init(); | ||
| 257 | if (res.IsFailure()) { | ||
| 258 | return res; | ||
| 259 | } | ||
| 260 | return conn; | ||
| 261 | } | ||
| 262 | |||
| 263 | namespace { | ||
| 264 | |||
| 265 | Result CheckOpenSSLErrors() { | ||
| 266 | unsigned long rc; | ||
| 267 | const char* file; | ||
| 268 | int line; | ||
| 269 | const char* func; | ||
| 270 | const char* data; | ||
| 271 | int flags; | ||
| 272 | while ((rc = ERR_get_error_all(&file, &line, &func, &data, &flags))) { | ||
| 273 | std::string msg; | ||
| 274 | msg.resize(1024, '\0'); | ||
| 275 | ERR_error_string_n(rc, msg.data(), msg.size()); | ||
| 276 | msg.resize(strlen(msg.data()), '\0'); | ||
| 277 | if (flags & ERR_TXT_STRING) { | ||
| 278 | msg.append(" | "); | ||
| 279 | msg.append(data); | ||
| 280 | } | ||
| 281 | Common::Log::FmtLogMessage(Common::Log::Class::Service_SSL, Common::Log::Level::Error, | ||
| 282 | Common::Log::TrimSourcePath(file), line, func, "OpenSSL: {}", | ||
| 283 | msg); | ||
| 284 | } | ||
| 285 | return ResultInternalError; | ||
| 286 | } | ||
| 287 | |||
| 288 | void OneTimeInit() { | ||
| 289 | ssl_ctx = SSL_CTX_new(TLS_client_method()); | ||
| 290 | if (!ssl_ctx) { | ||
| 291 | LOG_ERROR(Service_SSL, "SSL_CTX_new failed"); | ||
| 292 | CheckOpenSSLErrors(); | ||
| 293 | return; | ||
| 294 | } | ||
| 295 | |||
| 296 | SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr); | ||
| 297 | |||
| 298 | if (!SSL_CTX_set_default_verify_paths(ssl_ctx)) { | ||
| 299 | LOG_ERROR(Service_SSL, "SSL_CTX_set_default_verify_paths failed"); | ||
| 300 | CheckOpenSSLErrors(); | ||
| 301 | return; | ||
| 302 | } | ||
| 303 | |||
| 304 | OneTimeInitLogFile(); | ||
| 305 | |||
| 306 | if (!OneTimeInitBIO()) { | ||
| 307 | return; | ||
| 308 | } | ||
| 309 | |||
| 310 | one_time_init_success = true; | ||
| 311 | } | ||
| 312 | |||
| 313 | void OneTimeInitLogFile() { | ||
| 314 | const char* logfile = getenv("SSLKEYLOGFILE"); | ||
| 315 | if (logfile) { | ||
| 316 | key_log_file.Open(logfile, FileAccessMode::Append, FileType::TextFile, | ||
| 317 | FileShareFlag::ShareWriteOnly); | ||
| 318 | if (key_log_file.IsOpen()) { | ||
| 319 | SSL_CTX_set_keylog_callback(ssl_ctx, &SSLConnectionBackendOpenSSL::KeyLogCallback); | ||
| 320 | } else { | ||
| 321 | LOG_CRITICAL(Service_SSL, | ||
| 322 | "SSLKEYLOGFILE was set but file could not be opened; not logging keys!"); | ||
| 323 | } | ||
| 324 | } | ||
| 325 | } | ||
| 326 | |||
| 327 | bool OneTimeInitBIO() { | ||
| 328 | bio_meth = | ||
| 329 | BIO_meth_new(BIO_get_new_index() | BIO_TYPE_SOURCE_SINK, "SSLConnectionBackendOpenSSL"); | ||
| 330 | if (!bio_meth || | ||
| 331 | !BIO_meth_set_write_ex(bio_meth, &SSLConnectionBackendOpenSSL::WriteCallback) || | ||
| 332 | !BIO_meth_set_read_ex(bio_meth, &SSLConnectionBackendOpenSSL::ReadCallback) || | ||
| 333 | !BIO_meth_set_ctrl(bio_meth, &SSLConnectionBackendOpenSSL::CtrlCallback)) { | ||
| 334 | LOG_ERROR(Service_SSL, "Failed to create BIO_METHOD"); | ||
| 335 | return false; | ||
| 336 | } | ||
| 337 | return true; | ||
| 338 | } | ||
| 339 | |||
| 340 | } // namespace | ||
| 341 | |||
| 342 | } // namespace Service::SSL | ||
diff --git a/src/core/hle/service/ssl/ssl_backend_schannel.cpp b/src/core/hle/service/ssl/ssl_backend_schannel.cpp new file mode 100644 index 000000000..0a326b536 --- /dev/null +++ b/src/core/hle/service/ssl/ssl_backend_schannel.cpp | |||
| @@ -0,0 +1,529 @@ | |||
| 1 | // SPDX-FileCopyrightText: Copyright 2023 yuzu Emulator Project | ||
| 2 | // SPDX-License-Identifier: GPL-2.0-or-later | ||
| 3 | |||
| 4 | #include "core/hle/service/ssl/ssl_backend.h" | ||
| 5 | #include "core/internal_network/network.h" | ||
| 6 | #include "core/internal_network/sockets.h" | ||
| 7 | |||
| 8 | #include "common/error.h" | ||
| 9 | #include "common/fs/file.h" | ||
| 10 | #include "common/hex_util.h" | ||
| 11 | #include "common/string_util.h" | ||
| 12 | |||
| 13 | #include <mutex> | ||
| 14 | |||
| 15 | #define SECURITY_WIN32 | ||
| 16 | #include <Security.h> | ||
| 17 | #include <schnlsp.h> | ||
| 18 | |||
| 19 | namespace { | ||
| 20 | |||
| 21 | std::once_flag one_time_init_flag; | ||
| 22 | bool one_time_init_success = false; | ||
| 23 | |||
| 24 | SCHANNEL_CRED schannel_cred{ | ||
| 25 | .dwVersion = SCHANNEL_CRED_VERSION, | ||
| 26 | .dwFlags = SCH_USE_STRONG_CRYPTO | // don't allow insecure protocols | ||
| 27 | SCH_CRED_AUTO_CRED_VALIDATION | // validate certs | ||
| 28 | SCH_CRED_NO_DEFAULT_CREDS, // don't automatically present a client certificate | ||
| 29 | // ^ I'm assuming that nobody would want to connect Yuzu to a | ||
| 30 | // service that requires some OS-provided corporate client | ||
| 31 | // certificate, and presenting one to some arbitrary server | ||
| 32 | // might be a privacy concern? Who knows, though. | ||
| 33 | }; | ||
| 34 | |||
| 35 | CredHandle cred_handle; | ||
| 36 | |||
| 37 | static void OneTimeInit() { | ||
| 38 | SECURITY_STATUS ret = | ||
| 39 | AcquireCredentialsHandle(nullptr, const_cast<LPTSTR>(UNISP_NAME), SECPKG_CRED_OUTBOUND, | ||
| 40 | nullptr, &schannel_cred, nullptr, nullptr, &cred_handle, nullptr); | ||
| 41 | if (ret != SEC_E_OK) { | ||
| 42 | // SECURITY_STATUS codes are a type of HRESULT and can be used with NativeErrorToString. | ||
| 43 | LOG_ERROR(Service_SSL, "AcquireCredentialsHandle failed: {}", | ||
| 44 | Common::NativeErrorToString(ret)); | ||
| 45 | return; | ||
| 46 | } | ||
| 47 | |||
| 48 | one_time_init_success = true; | ||
| 49 | } | ||
| 50 | |||
| 51 | } // namespace | ||
| 52 | |||
| 53 | namespace Service::SSL { | ||
| 54 | |||
| 55 | class SSLConnectionBackendSchannel final : public SSLConnectionBackend { | ||
| 56 | public: | ||
| 57 | Result Init() { | ||
| 58 | std::call_once(one_time_init_flag, OneTimeInit); | ||
| 59 | |||
| 60 | if (!one_time_init_success) { | ||
| 61 | LOG_ERROR( | ||
| 62 | Service_SSL, | ||
| 63 | "Can't create SSL connection because Schannel one-time initialization failed"); | ||
| 64 | return ResultInternalError; | ||
| 65 | } | ||
| 66 | |||
| 67 | return ResultSuccess; | ||
| 68 | } | ||
| 69 | |||
| 70 | void SetSocket(std::shared_ptr<Network::SocketBase> socket) override { | ||
| 71 | socket_ = socket; | ||
| 72 | } | ||
| 73 | |||
| 74 | Result SetHostName(const std::string& hostname) override { | ||
| 75 | hostname_ = hostname; | ||
| 76 | return ResultSuccess; | ||
| 77 | } | ||
| 78 | |||
| 79 | Result DoHandshake() override { | ||
| 80 | while (1) { | ||
| 81 | Result r; | ||
| 82 | switch (handshake_state_) { | ||
| 83 | case HandshakeState::Initial: | ||
| 84 | if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || | ||
| 85 | (r = CallInitializeSecurityContext()) != ResultSuccess) { | ||
| 86 | return r; | ||
| 87 | } | ||
| 88 | // CallInitializeSecurityContext updated `handshake_state_`. | ||
| 89 | continue; | ||
| 90 | case HandshakeState::ContinueNeeded: | ||
| 91 | case HandshakeState::IncompleteMessage: | ||
| 92 | if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || | ||
| 93 | (r = FillCiphertextReadBuf()) != ResultSuccess) { | ||
| 94 | return r; | ||
| 95 | } | ||
| 96 | if (ciphertext_read_buf_.empty()) { | ||
| 97 | LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); | ||
| 98 | return ResultInternalError; | ||
| 99 | } | ||
| 100 | if ((r = CallInitializeSecurityContext()) != ResultSuccess) { | ||
| 101 | return r; | ||
| 102 | } | ||
| 103 | // CallInitializeSecurityContext updated `handshake_state_`. | ||
| 104 | continue; | ||
| 105 | case HandshakeState::DoneAfterFlush: | ||
| 106 | if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) { | ||
| 107 | return r; | ||
| 108 | } | ||
| 109 | handshake_state_ = HandshakeState::Connected; | ||
| 110 | return ResultSuccess; | ||
| 111 | case HandshakeState::Connected: | ||
| 112 | LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook"); | ||
| 113 | return ResultInternalError; | ||
| 114 | case HandshakeState::Error: | ||
| 115 | return ResultInternalError; | ||
| 116 | } | ||
| 117 | } | ||
| 118 | } | ||
| 119 | |||
| 120 | Result FillCiphertextReadBuf() { | ||
| 121 | size_t fill_size = read_buf_fill_size_ ? read_buf_fill_size_ : 4096; | ||
| 122 | read_buf_fill_size_ = 0; | ||
| 123 | // This unnecessarily zeroes the buffer; oh well. | ||
| 124 | size_t offset = ciphertext_read_buf_.size(); | ||
| 125 | ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; }); | ||
| 126 | ciphertext_read_buf_.resize(offset + fill_size, 0); | ||
| 127 | auto read_span = std::span(ciphertext_read_buf_).subspan(offset, fill_size); | ||
| 128 | auto [actual, err] = socket_->Recv(0, read_span); | ||
| 129 | switch (err) { | ||
| 130 | case Network::Errno::SUCCESS: | ||
| 131 | ASSERT(static_cast<size_t>(actual) <= fill_size); | ||
| 132 | ciphertext_read_buf_.resize(offset + actual); | ||
| 133 | return ResultSuccess; | ||
| 134 | case Network::Errno::AGAIN: | ||
| 135 | ciphertext_read_buf_.resize(offset); | ||
| 136 | return ResultWouldBlock; | ||
| 137 | default: | ||
| 138 | ciphertext_read_buf_.resize(offset); | ||
| 139 | LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); | ||
| 140 | return ResultInternalError; | ||
| 141 | } | ||
| 142 | } | ||
| 143 | |||
| 144 | // Returns success if the write buffer has been completely emptied. | ||
| 145 | Result FlushCiphertextWriteBuf() { | ||
| 146 | while (!ciphertext_write_buf_.empty()) { | ||
| 147 | auto [actual, err] = socket_->Send(ciphertext_write_buf_, 0); | ||
| 148 | switch (err) { | ||
| 149 | case Network::Errno::SUCCESS: | ||
| 150 | ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf_.size()); | ||
| 151 | ciphertext_write_buf_.erase(ciphertext_write_buf_.begin(), | ||
| 152 | ciphertext_write_buf_.begin() + actual); | ||
| 153 | break; | ||
| 154 | case Network::Errno::AGAIN: | ||
| 155 | return ResultWouldBlock; | ||
| 156 | default: | ||
| 157 | LOG_ERROR(Service_SSL, "Socket send returned Network::Errno {}", err); | ||
| 158 | return ResultInternalError; | ||
| 159 | } | ||
| 160 | } | ||
| 161 | return ResultSuccess; | ||
| 162 | } | ||
| 163 | |||
| 164 | Result CallInitializeSecurityContext() { | ||
| 165 | unsigned long req = ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_CONFIDENTIALITY | ISC_REQ_INTEGRITY | | ||
| 166 | ISC_REQ_REPLAY_DETECT | ISC_REQ_SEQUENCE_DETECT | ISC_REQ_STREAM | | ||
| 167 | ISC_REQ_USE_SUPPLIED_CREDS; | ||
| 168 | unsigned long attr; | ||
| 169 | // https://learn.microsoft.com/en-us/windows/win32/secauthn/initializesecuritycontext--schannel | ||
| 170 | std::array<SecBuffer, 2> input_buffers{{ | ||
| 171 | // only used if `initial_call_done` | ||
| 172 | { | ||
| 173 | // [0] | ||
| 174 | .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf_.size()), | ||
| 175 | .BufferType = SECBUFFER_TOKEN, | ||
| 176 | .pvBuffer = ciphertext_read_buf_.data(), | ||
| 177 | }, | ||
| 178 | { | ||
| 179 | // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is | ||
| 180 | // returned, or SECBUFFER_EXTRA when SEC_E_CONTINUE_NEEDED is returned if the | ||
| 181 | // whole buffer wasn't used) | ||
| 182 | .BufferType = SECBUFFER_EMPTY, | ||
| 183 | }, | ||
| 184 | }}; | ||
| 185 | std::array<SecBuffer, 2> output_buffers{{ | ||
| 186 | { | ||
| 187 | .BufferType = SECBUFFER_TOKEN, | ||
| 188 | }, // [0] | ||
| 189 | { | ||
| 190 | .BufferType = SECBUFFER_ALERT, | ||
| 191 | }, // [1] | ||
| 192 | }}; | ||
| 193 | SecBufferDesc input_desc{ | ||
| 194 | .ulVersion = SECBUFFER_VERSION, | ||
| 195 | .cBuffers = static_cast<unsigned long>(input_buffers.size()), | ||
| 196 | .pBuffers = input_buffers.data(), | ||
| 197 | }; | ||
| 198 | SecBufferDesc output_desc{ | ||
| 199 | .ulVersion = SECBUFFER_VERSION, | ||
| 200 | .cBuffers = static_cast<unsigned long>(output_buffers.size()), | ||
| 201 | .pBuffers = output_buffers.data(), | ||
| 202 | }; | ||
| 203 | ASSERT_OR_EXECUTE_MSG( | ||
| 204 | input_buffers[0].cbBuffer == ciphertext_read_buf_.size(), | ||
| 205 | { return ResultInternalError; }, "read buffer too large"); | ||
| 206 | |||
| 207 | bool initial_call_done = handshake_state_ != HandshakeState::Initial; | ||
| 208 | if (initial_call_done) { | ||
| 209 | LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext", | ||
| 210 | ciphertext_read_buf_.size()); | ||
| 211 | } | ||
| 212 | |||
| 213 | SECURITY_STATUS ret = | ||
| 214 | InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt_ : nullptr, | ||
| 215 | // Caller ensured we have set a hostname: | ||
| 216 | const_cast<char*>(hostname_.value().c_str()), req, | ||
| 217 | 0, // Reserved1 | ||
| 218 | 0, // TargetDataRep not used with Schannel | ||
| 219 | initial_call_done ? &input_desc : nullptr, | ||
| 220 | 0, // Reserved2 | ||
| 221 | initial_call_done ? nullptr : &ctxt_, &output_desc, &attr, | ||
| 222 | nullptr); // ptsExpiry | ||
| 223 | |||
| 224 | if (output_buffers[0].pvBuffer) { | ||
| 225 | std::span span(static_cast<u8*>(output_buffers[0].pvBuffer), | ||
| 226 | output_buffers[0].cbBuffer); | ||
| 227 | ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), span.begin(), span.end()); | ||
| 228 | FreeContextBuffer(output_buffers[0].pvBuffer); | ||
| 229 | } | ||
| 230 | |||
| 231 | if (output_buffers[1].pvBuffer) { | ||
| 232 | std::span span(static_cast<u8*>(output_buffers[1].pvBuffer), | ||
| 233 | output_buffers[1].cbBuffer); | ||
| 234 | // The documentation doesn't explain what format this data is in. | ||
| 235 | LOG_DEBUG(Service_SSL, "Got a {}-byte alert buffer: {}", span.size(), | ||
| 236 | Common::HexToString(span)); | ||
| 237 | } | ||
| 238 | |||
| 239 | switch (ret) { | ||
| 240 | case SEC_I_CONTINUE_NEEDED: | ||
| 241 | LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED"); | ||
| 242 | if (input_buffers[1].BufferType == SECBUFFER_EXTRA) { | ||
| 243 | LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer); | ||
| 244 | ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf_.size()); | ||
| 245 | ciphertext_read_buf_.erase(ciphertext_read_buf_.begin(), | ||
| 246 | ciphertext_read_buf_.end() - input_buffers[1].cbBuffer); | ||
| 247 | } else { | ||
| 248 | ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY); | ||
| 249 | ciphertext_read_buf_.clear(); | ||
| 250 | } | ||
| 251 | handshake_state_ = HandshakeState::ContinueNeeded; | ||
| 252 | return ResultSuccess; | ||
| 253 | case SEC_E_INCOMPLETE_MESSAGE: | ||
| 254 | LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE"); | ||
| 255 | ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING); | ||
| 256 | read_buf_fill_size_ = input_buffers[1].cbBuffer; | ||
| 257 | handshake_state_ = HandshakeState::IncompleteMessage; | ||
| 258 | return ResultSuccess; | ||
| 259 | case SEC_E_OK: | ||
| 260 | LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK"); | ||
| 261 | ciphertext_read_buf_.clear(); | ||
| 262 | handshake_state_ = HandshakeState::DoneAfterFlush; | ||
| 263 | return GrabStreamSizes(); | ||
| 264 | default: | ||
| 265 | LOG_ERROR(Service_SSL, | ||
| 266 | "InitializeSecurityContext failed (probably certificate/protocol issue): {}", | ||
| 267 | Common::NativeErrorToString(ret)); | ||
| 268 | handshake_state_ = HandshakeState::Error; | ||
| 269 | return ResultInternalError; | ||
| 270 | } | ||
| 271 | } | ||
| 272 | |||
| 273 | Result GrabStreamSizes() { | ||
| 274 | SECURITY_STATUS ret = | ||
| 275 | QueryContextAttributes(&ctxt_, SECPKG_ATTR_STREAM_SIZES, &stream_sizes_); | ||
| 276 | if (ret != SEC_E_OK) { | ||
| 277 | LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}", | ||
| 278 | Common::NativeErrorToString(ret)); | ||
| 279 | handshake_state_ = HandshakeState::Error; | ||
| 280 | return ResultInternalError; | ||
| 281 | } | ||
| 282 | return ResultSuccess; | ||
| 283 | } | ||
| 284 | |||
| 285 | ResultVal<size_t> Read(std::span<u8> data) override { | ||
| 286 | if (handshake_state_ != HandshakeState::Connected) { | ||
| 287 | LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake"); | ||
| 288 | return ResultInternalError; | ||
| 289 | } | ||
| 290 | if (data.size() == 0 || got_read_eof_) { | ||
| 291 | return size_t(0); | ||
| 292 | } | ||
| 293 | while (1) { | ||
| 294 | if (!cleartext_read_buf_.empty()) { | ||
| 295 | size_t read_size = std::min(cleartext_read_buf_.size(), data.size()); | ||
| 296 | std::memcpy(data.data(), cleartext_read_buf_.data(), read_size); | ||
| 297 | cleartext_read_buf_.erase(cleartext_read_buf_.begin(), | ||
| 298 | cleartext_read_buf_.begin() + read_size); | ||
| 299 | return read_size; | ||
| 300 | } | ||
| 301 | if (!ciphertext_read_buf_.empty()) { | ||
| 302 | std::array<SecBuffer, 5> buffers{{ | ||
| 303 | { | ||
| 304 | .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf_.size()), | ||
| 305 | .BufferType = SECBUFFER_DATA, | ||
| 306 | .pvBuffer = ciphertext_read_buf_.data(), | ||
| 307 | }, | ||
| 308 | { | ||
| 309 | .BufferType = SECBUFFER_EMPTY, | ||
| 310 | }, | ||
| 311 | { | ||
| 312 | .BufferType = SECBUFFER_EMPTY, | ||
| 313 | }, | ||
| 314 | { | ||
| 315 | .BufferType = SECBUFFER_EMPTY, | ||
| 316 | }, | ||
| 317 | }}; | ||
| 318 | ASSERT_OR_EXECUTE_MSG( | ||
| 319 | buffers[0].cbBuffer == ciphertext_read_buf_.size(), | ||
| 320 | { return ResultInternalError; }, "read buffer too large"); | ||
| 321 | SecBufferDesc desc{ | ||
| 322 | .ulVersion = SECBUFFER_VERSION, | ||
| 323 | .cBuffers = static_cast<unsigned long>(buffers.size()), | ||
| 324 | .pBuffers = buffers.data(), | ||
| 325 | }; | ||
| 326 | SECURITY_STATUS ret = | ||
| 327 | DecryptMessage(&ctxt_, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr); | ||
| 328 | switch (ret) { | ||
| 329 | case SEC_E_OK: | ||
| 330 | ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER, | ||
| 331 | { return ResultInternalError; }); | ||
| 332 | ASSERT_OR_EXECUTE(buffers[1].BufferType == SECBUFFER_DATA, | ||
| 333 | { return ResultInternalError; }); | ||
| 334 | ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER, | ||
| 335 | { return ResultInternalError; }); | ||
| 336 | cleartext_read_buf_.assign(static_cast<u8*>(buffers[1].pvBuffer), | ||
| 337 | static_cast<u8*>(buffers[1].pvBuffer) + | ||
| 338 | buffers[1].cbBuffer); | ||
| 339 | if (buffers[3].BufferType == SECBUFFER_EXTRA) { | ||
| 340 | ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf_.size()); | ||
| 341 | ciphertext_read_buf_.erase(ciphertext_read_buf_.begin(), | ||
| 342 | ciphertext_read_buf_.end() - | ||
| 343 | buffers[3].cbBuffer); | ||
| 344 | } else { | ||
| 345 | ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY); | ||
| 346 | ciphertext_read_buf_.clear(); | ||
| 347 | } | ||
| 348 | continue; | ||
| 349 | case SEC_E_INCOMPLETE_MESSAGE: | ||
| 350 | break; | ||
| 351 | case SEC_I_CONTEXT_EXPIRED: | ||
| 352 | // Server hung up by sending close_notify. | ||
| 353 | got_read_eof_ = true; | ||
| 354 | return size_t(0); | ||
| 355 | default: | ||
| 356 | LOG_ERROR(Service_SSL, "DecryptMessage failed: {}", | ||
| 357 | Common::NativeErrorToString(ret)); | ||
| 358 | return ResultInternalError; | ||
| 359 | } | ||
| 360 | } | ||
| 361 | Result r = FillCiphertextReadBuf(); | ||
| 362 | if (r != ResultSuccess) { | ||
| 363 | return r; | ||
| 364 | } | ||
| 365 | if (ciphertext_read_buf_.empty()) { | ||
| 366 | got_read_eof_ = true; | ||
| 367 | return size_t(0); | ||
| 368 | } | ||
| 369 | } | ||
| 370 | } | ||
| 371 | |||
| 372 | ResultVal<size_t> Write(std::span<const u8> data) override { | ||
| 373 | if (handshake_state_ != HandshakeState::Connected) { | ||
| 374 | LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake"); | ||
| 375 | return ResultInternalError; | ||
| 376 | } | ||
| 377 | if (data.size() == 0) { | ||
| 378 | return size_t(0); | ||
| 379 | } | ||
| 380 | data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes_.cbMaximumMessage)); | ||
| 381 | if (!cleartext_write_buf_.empty()) { | ||
| 382 | // Already in the middle of a write. It wouldn't make sense to not | ||
| 383 | // finish sending the entire buffer since TLS has | ||
| 384 | // header/MAC/padding/etc. | ||
| 385 | if (data.size() != cleartext_write_buf_.size() || | ||
| 386 | std::memcmp(data.data(), cleartext_write_buf_.data(), data.size())) { | ||
| 387 | LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer"); | ||
| 388 | return ResultInternalError; | ||
| 389 | } | ||
| 390 | return WriteAlreadyEncryptedData(); | ||
| 391 | } else { | ||
| 392 | cleartext_write_buf_.assign(data.begin(), data.end()); | ||
| 393 | } | ||
| 394 | |||
| 395 | std::vector<u8> header_buf(stream_sizes_.cbHeader, 0); | ||
| 396 | std::vector<u8> tmp_data_buf = cleartext_write_buf_; | ||
| 397 | std::vector<u8> trailer_buf(stream_sizes_.cbTrailer, 0); | ||
| 398 | |||
| 399 | std::array<SecBuffer, 3> buffers{{ | ||
| 400 | { | ||
| 401 | .cbBuffer = stream_sizes_.cbHeader, | ||
| 402 | .BufferType = SECBUFFER_STREAM_HEADER, | ||
| 403 | .pvBuffer = header_buf.data(), | ||
| 404 | }, | ||
| 405 | { | ||
| 406 | .cbBuffer = static_cast<unsigned long>(tmp_data_buf.size()), | ||
| 407 | .BufferType = SECBUFFER_DATA, | ||
| 408 | .pvBuffer = tmp_data_buf.data(), | ||
| 409 | }, | ||
| 410 | { | ||
| 411 | .cbBuffer = stream_sizes_.cbTrailer, | ||
| 412 | .BufferType = SECBUFFER_STREAM_TRAILER, | ||
| 413 | .pvBuffer = trailer_buf.data(), | ||
| 414 | }, | ||
| 415 | }}; | ||
| 416 | ASSERT_OR_EXECUTE_MSG( | ||
| 417 | buffers[1].cbBuffer == tmp_data_buf.size(), { return ResultInternalError; }, | ||
| 418 | "temp buffer too large"); | ||
| 419 | SecBufferDesc desc{ | ||
| 420 | .ulVersion = SECBUFFER_VERSION, | ||
| 421 | .cBuffers = static_cast<unsigned long>(buffers.size()), | ||
| 422 | .pBuffers = buffers.data(), | ||
| 423 | }; | ||
| 424 | |||
| 425 | SECURITY_STATUS ret = EncryptMessage(&ctxt_, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0); | ||
| 426 | if (ret != SEC_E_OK) { | ||
| 427 | LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret)); | ||
| 428 | return ResultInternalError; | ||
| 429 | } | ||
| 430 | ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), header_buf.begin(), | ||
| 431 | header_buf.end()); | ||
| 432 | ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), tmp_data_buf.begin(), | ||
| 433 | tmp_data_buf.end()); | ||
| 434 | ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), trailer_buf.begin(), | ||
| 435 | trailer_buf.end()); | ||
| 436 | return WriteAlreadyEncryptedData(); | ||
| 437 | } | ||
| 438 | |||
| 439 | ResultVal<size_t> WriteAlreadyEncryptedData() { | ||
| 440 | Result r = FlushCiphertextWriteBuf(); | ||
| 441 | if (r != ResultSuccess) { | ||
| 442 | return r; | ||
| 443 | } | ||
| 444 | // write buf is empty | ||
| 445 | size_t cleartext_bytes_written = cleartext_write_buf_.size(); | ||
| 446 | cleartext_write_buf_.clear(); | ||
| 447 | return cleartext_bytes_written; | ||
| 448 | } | ||
| 449 | |||
| 450 | ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { | ||
| 451 | PCCERT_CONTEXT returned_cert = nullptr; | ||
| 452 | SECURITY_STATUS ret = | ||
| 453 | QueryContextAttributes(&ctxt_, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert); | ||
| 454 | if (ret != SEC_E_OK) { | ||
| 455 | LOG_ERROR(Service_SSL, | ||
| 456 | "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}", | ||
| 457 | Common::NativeErrorToString(ret)); | ||
| 458 | return ResultInternalError; | ||
| 459 | } | ||
| 460 | PCCERT_CONTEXT some_cert = nullptr; | ||
| 461 | std::vector<std::vector<u8>> certs; | ||
| 462 | while ((some_cert = CertEnumCertificatesInStore(returned_cert->hCertStore, some_cert))) { | ||
| 463 | certs.emplace_back(static_cast<u8*>(some_cert->pbCertEncoded), | ||
| 464 | static_cast<u8*>(some_cert->pbCertEncoded) + | ||
| 465 | some_cert->cbCertEncoded); | ||
| 466 | } | ||
| 467 | std::reverse(certs.begin(), | ||
| 468 | certs.end()); // Windows returns certs in reverse order from what we want | ||
| 469 | CertFreeCertificateContext(returned_cert); | ||
| 470 | return certs; | ||
| 471 | } | ||
| 472 | |||
| 473 | ~SSLConnectionBackendSchannel() { | ||
| 474 | if (handshake_state_ != HandshakeState::Initial) { | ||
| 475 | DeleteSecurityContext(&ctxt_); | ||
| 476 | } | ||
| 477 | } | ||
| 478 | |||
| 479 | enum class HandshakeState { | ||
| 480 | // Haven't called anything yet. | ||
| 481 | Initial, | ||
| 482 | // `SEC_I_CONTINUE_NEEDED` was returned by | ||
| 483 | // `InitializeSecurityContext`; must finish sending data (if any) in | ||
| 484 | // the write buffer, then read at least one byte before calling | ||
| 485 | // `InitializeSecurityContext` again. | ||
| 486 | ContinueNeeded, | ||
| 487 | // `SEC_E_INCOMPLETE_MESSAGE` was returned by | ||
| 488 | // `InitializeSecurityContext`; hopefully the write buffer is empty; | ||
| 489 | // must read at least one byte before calling | ||
| 490 | // `InitializeSecurityContext` again. | ||
| 491 | IncompleteMessage, | ||
| 492 | // `SEC_E_OK` was returned by `InitializeSecurityContext`; must | ||
| 493 | // finish sending data in the write buffer before having `DoHandshake` | ||
| 494 | // report success. | ||
| 495 | DoneAfterFlush, | ||
| 496 | // We finished the above and are now connected. At this point, writing | ||
| 497 | // and reading are separate 'state machines' represented by the | ||
| 498 | // nonemptiness of the ciphertext and cleartext read and write buffers. | ||
| 499 | Connected, | ||
| 500 | // Another error was returned and we shouldn't allow initialization | ||
| 501 | // to continue. | ||
| 502 | Error, | ||
| 503 | } handshake_state_ = HandshakeState::Initial; | ||
| 504 | |||
| 505 | CtxtHandle ctxt_; | ||
| 506 | SecPkgContext_StreamSizes stream_sizes_; | ||
| 507 | |||
| 508 | std::shared_ptr<Network::SocketBase> socket_; | ||
| 509 | std::optional<std::string> hostname_; | ||
| 510 | |||
| 511 | std::vector<u8> ciphertext_read_buf_; | ||
| 512 | std::vector<u8> ciphertext_write_buf_; | ||
| 513 | std::vector<u8> cleartext_read_buf_; | ||
| 514 | std::vector<u8> cleartext_write_buf_; | ||
| 515 | |||
| 516 | bool got_read_eof_ = false; | ||
| 517 | size_t read_buf_fill_size_ = 0; | ||
| 518 | }; | ||
| 519 | |||
| 520 | ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { | ||
| 521 | auto conn = std::make_unique<SSLConnectionBackendSchannel>(); | ||
| 522 | Result res = conn->Init(); | ||
| 523 | if (res.IsFailure()) { | ||
| 524 | return res; | ||
| 525 | } | ||
| 526 | return conn; | ||
| 527 | } | ||
| 528 | |||
| 529 | } // namespace Service::SSL | ||