summaryrefslogtreecommitdiff
path: root/src/core/internal_network
diff options
context:
space:
mode:
authorGravatar comex2023-06-19 18:17:43 -0700
committerGravatar comex2023-06-25 12:53:31 -0700
commit8e703e08dfcf735a08df2ceff6a05221b7cc981f (patch)
tree771ebe71883ff9e179156f2b38b21b05070d7667 /src/core/internal_network
parentMerge pull request #10825 from 8bitDream/vcpkg-zlib (diff)
downloadyuzu-8e703e08dfcf735a08df2ceff6a05221b7cc981f.tar.gz
yuzu-8e703e08dfcf735a08df2ceff6a05221b7cc981f.tar.xz
yuzu-8e703e08dfcf735a08df2ceff6a05221b7cc981f.zip
Implement SSL service
This implements some missing network APIs including a large chunk of the SSL service, enough for Mario Maker (with an appropriate mod applied) to connect to the fan server [Open Course World](https://opencourse.world/). Connecting to first-party servers is out of scope of this PR and is a minefield I'd rather not step into. ## TLS TLS is implemented with multiple backends depending on the system's 'native' TLS library. Currently there are two backends: Schannel for Windows, and OpenSSL for Linux. (In reality Linux is a bit of a free-for-all where there's no one 'native' library, but OpenSSL is the closest it gets.) On macOS the 'native' library is SecureTransport but that isn't implemented in this PR. (Instead, all non-Windows OSes will use OpenSSL unless disabled with `-DENABLE_OPENSSL=OFF`.) Why have multiple backends instead of just using a single library, especially given that Yuzu already embeds mbedtls for cryptographic algorithms? Well, I tried implementing this on mbedtls first, but the problem is TLS policies - mainly trusted certificate policies, and to a lesser extent trusted algorithms, SSL versions, etc. ...In practice, the chance that someone is going to conduct a man-in-the-middle attack on a third-party game server is pretty low, but I'm a security nerd so I like to do the right security things. My base assumption is that we want to use the host system's TLS policies. An alternative would be to more closely emulate the Switch's TLS implementation (which is based on NSS). But for one thing, I don't feel like reverse engineering it. And I'd argue that for third-party servers such as Open Course World, it's theoretically preferable to use the system's policies rather than the Switch's, for two reasons 1. Someday the Switch will stop being updated, and the trusted cert list, algorithms, etc. will start to go stale, but users will still want to connect to third-party servers, and there's no reason they shouldn't have up-to-date security when doing so. At that point, homebrew users on actual hardware may patch the TLS implementation, but for emulators it's simpler to just use the host's stack. 2. Also, it's good to respect any custom certificate policies the user may have added systemwide. For example, they may have added custom trusted CAs in order to use TLS debugging tools or pass through corporate MitM middleboxes. Or they may have removed some CAs that are normally trusted out of paranoia. Note that this policy wouldn't work as-is for connecting to first-party servers, because some of them serve certificates based on Nintendo's own CA rather than a publicly trusted one. However, this could probably be solved easily by using appropriate APIs to adding Nintendo's CA as an alternate trusted cert for Yuzu's connections. That is not implemented in this PR because, again, first-party servers are out of scope. (If anything I'd rather have an option to _block_ connections to Nintendo servers, but that's not implemented here.) To use the host's TLS policies, there are three theoretical options: a) Import the host's trusted certificate list into a cross-platform TLS library (presumably mbedtls). b) Use the native TLS library to verify certificates but use a cross-platform TLS library for everything else. c) Use the native TLS library for everything. Two problems with option a). First, importing the trusted certificate list at minimum requires a bunch of platform-specific code, which mbedtls does not have built in. Interestingly, OpenSSL recently gained the ability to import the Windows certificate trust store... but that leads to the second problem, which is that a list of trusted certificates is [not expressive enough](https://bugs.archlinux.org/task/41909) to express a modern certificate trust policy. For example, Windows has the concept of [explicitly distrusted certificates](https://learn.microsoft.com/en-us/previous-versions/windows/it-pro/windows-server-2012-r2-and-2012/dn265983(v=ws.11)), and macOS requires Certificate Transparency validation for some certificates with complex rules for when it's required. Option b) (using native library just to verify certs) is probably feasible, but it would miss aspects of TLS policy other than trusted certs (like allowed algorithms), and in any case it might well require writing more code, not less, compared to using the native library for everything. So I ended up at option c), using the native library for everything. What I'd *really* prefer would be to use a third-party library that does option c) for me. Rust has a good library for this, [native-tls](https://docs.rs/native-tls/latest/native_tls/). I did search, but I couldn't find a good option in the C or C++ ecosystem, at least not any that wasn't part of some much larger framework. I was surprised - isn't this a pretty common use case? Well, many applications only need TLS for HTTPS, and they can use libcurl, which has a TLS abstraction layer internally but doesn't expose it. Other applications only support a single TLS library, or use one of the aforementioned larger frameworks, or are platform-specific to begin with, or of course are written in a non-C/C++ language, most of which have some canonical choice for TLS. But there are also many applications that have a set of TLS backends just like this; it's just that nobody has gone ahead and abstracted the pattern into a library, at least not a widespread one. Amusingly, there is one TLS abstraction layer that Yuzu already bundles: the one in ffmpeg. But it is missing some features that would be needed to use it here (like reusing an existing socket rather than managing the socket itself). Though, that does mean that the wiki's build instructions for Linux (and macOS for some reason?) already recommend installing OpenSSL, so no need to update those. ## Other APIs implemented - Sockets: - GetSockOpt(`SO_ERROR`) - SetSockOpt(`SO_NOSIGPIPE`) (stub, I have no idea what this does on Switch) - `DuplicateSocket` (because the SSL sysmodule calls it internally) - More `PollEvents` values - NSD: - `Resolve` and `ResolveEx` (stub, good enough for Open Course World and probably most third-party servers, but not first-party) - SFDNSRES: - `GetHostByNameRequest` and `GetHostByNameRequestWithOptions` - `ResolverSetOptionRequest` (stub) ## Fixes - Parts of the socket code were previously allocating a `sockaddr` object on the stack when calling functions that take a `sockaddr*` (e.g. `accept`). This might seem like the right thing to do to avoid illegal aliasing, but in fact `sockaddr` is not guaranteed to be large enough to hold any particular type of address, only the header. This worked in practice because in practice `sockaddr` is the same size as `sockaddr_in`, but it's not how the API is meant to be used. I changed this to allocate an `sockaddr_in` on the stack and `reinterpret_cast` it. I could try to do something cleverer with `aligned_storage`, but casting is the idiomatic way to use these particular APIs, so it's really the system's responsibility to avoid any aliasing issues. - I rewrote most of the `GetAddrInfoRequest[WithOptions]` implementation. The old implementation invoked the host's getaddrinfo directly from sfdnsres.cpp, and directly passed through the host's socket type, protocol, etc. values rather than looking up the corresponding constants on the Switch. To be fair, these constants don't tend to actually vary across systems, but still... I added a wrapper for `getaddrinfo` in `internal_network/network.cpp` similar to the ones for other socket APIs, and changed the `GetAddrInfoRequest` implementation to use it. While I was at it, I rewrote the serialization to use the same approach I used to implement `GetHostByNameRequest`, because it reduces the number of size calculations. While doing so I removed `AF_INET6` support because the Switch doesn't support IPv6; it might be nice to support IPv6 anyway, but that would have to apply to all of the socket APIs. I also corrected the IPC wrappers for `GetAddrInfoRequest` and `GetAddrInfoRequestWithOptions` based on reverse engineering and hardware testing. Every call to `GetAddrInfoRequestWithOptions` returns *four* different error codes (IPC status, getaddrinfo error code, netdb error code, and errno), and `GetAddrInfoRequest` returns three of those but in a different order, and it doesn't really matter but the existing implementation was a bit off, as I discovered while testing `GetHostByNameRequest`. - The new serialization code is based on two simple helper functions: ```cpp template <typename T> static void Append(std::vector<u8>& vec, T t); void AppendNulTerminated(std::vector<u8>& vec, std::string_view str); ``` I was thinking there must be existing functions somewhere that assist with serialization/deserialization of binary data, but all I could find was the helper methods in `IOFile` and `HLERequestContext`, not anything that could be used with a generic byte buffer. If I'm not missing something, then maybe I should move the above functions to a new header in `common`... right now they're just sitting in `sfdnsres.cpp` where they're used. - Not a fix, but `SocketBase::Recv`/`Send` is changed to use `std::span<u8>` rather than `std::vector<u8>&` to avoid needing to copy the data to/from a vector when those methods are called from the TLS implementation.
Diffstat (limited to 'src/core/internal_network')
-rw-r--r--src/core/internal_network/network.cpp274
-rw-r--r--src/core/internal_network/network.h34
-rw-r--r--src/core/internal_network/socket_proxy.cpp22
-rw-r--r--src/core/internal_network/socket_proxy.h8
-rw-r--r--src/core/internal_network/sockets.h16
5 files changed, 281 insertions, 73 deletions
diff --git a/src/core/internal_network/network.cpp b/src/core/internal_network/network.cpp
index 75ac10a9c..39381e06e 100644
--- a/src/core/internal_network/network.cpp
+++ b/src/core/internal_network/network.cpp
@@ -121,6 +121,8 @@ Errno TranslateNativeError(int e) {
121 return Errno::MSGSIZE; 121 return Errno::MSGSIZE;
122 case WSAETIMEDOUT: 122 case WSAETIMEDOUT:
123 return Errno::TIMEDOUT; 123 return Errno::TIMEDOUT;
124 case WSAEINPROGRESS:
125 return Errno::INPROGRESS;
124 default: 126 default:
125 UNIMPLEMENTED_MSG("Unimplemented errno={}", e); 127 UNIMPLEMENTED_MSG("Unimplemented errno={}", e);
126 return Errno::OTHER; 128 return Errno::OTHER;
@@ -195,6 +197,8 @@ bool EnableNonBlock(int fd, bool enable) {
195 197
196Errno TranslateNativeError(int e) { 198Errno TranslateNativeError(int e) {
197 switch (e) { 199 switch (e) {
200 case 0:
201 return Errno::SUCCESS;
198 case EBADF: 202 case EBADF:
199 return Errno::BADF; 203 return Errno::BADF;
200 case EINVAL: 204 case EINVAL:
@@ -219,8 +223,10 @@ Errno TranslateNativeError(int e) {
219 return Errno::MSGSIZE; 223 return Errno::MSGSIZE;
220 case ETIMEDOUT: 224 case ETIMEDOUT:
221 return Errno::TIMEDOUT; 225 return Errno::TIMEDOUT;
226 case EINPROGRESS:
227 return Errno::INPROGRESS;
222 default: 228 default:
223 UNIMPLEMENTED_MSG("Unimplemented errno={}", e); 229 UNIMPLEMENTED_MSG("Unimplemented errno={} ({})", e, strerror(e));
224 return Errno::OTHER; 230 return Errno::OTHER;
225 } 231 }
226} 232}
@@ -234,15 +240,84 @@ Errno GetAndLogLastError() {
234 int e = errno; 240 int e = errno;
235#endif 241#endif
236 const Errno err = TranslateNativeError(e); 242 const Errno err = TranslateNativeError(e);
237 if (err == Errno::AGAIN || err == Errno::TIMEDOUT) { 243 if (err == Errno::AGAIN || err == Errno::TIMEDOUT || err == Errno::INPROGRESS) {
244 // These happen during normal operation, so only log them at debug level.
245 LOG_DEBUG(Network, "Socket operation error: {}", Common::NativeErrorToString(e));
238 return err; 246 return err;
239 } 247 }
240 LOG_ERROR(Network, "Socket operation error: {}", Common::NativeErrorToString(e)); 248 LOG_ERROR(Network, "Socket operation error: {}", Common::NativeErrorToString(e));
241 return err; 249 return err;
242} 250}
243 251
244int TranslateDomain(Domain domain) { 252GetAddrInfoError TranslateGetAddrInfoErrorFromNative(int gai_err) {
253 switch (gai_err) {
254 case 0:
255 return GetAddrInfoError::SUCCESS;
256#ifdef EAI_ADDRFAMILY
257 case EAI_ADDRFAMILY:
258 return GetAddrInfoError::ADDRFAMILY;
259#endif
260 case EAI_AGAIN:
261 return GetAddrInfoError::AGAIN;
262 case EAI_BADFLAGS:
263 return GetAddrInfoError::BADFLAGS;
264 case EAI_FAIL:
265 return GetAddrInfoError::FAIL;
266 case EAI_FAMILY:
267 return GetAddrInfoError::FAMILY;
268 case EAI_MEMORY:
269 return GetAddrInfoError::MEMORY;
270 case EAI_NONAME:
271 return GetAddrInfoError::NONAME;
272 case EAI_SERVICE:
273 return GetAddrInfoError::SERVICE;
274 case EAI_SOCKTYPE:
275 return GetAddrInfoError::SOCKTYPE;
276 // These codes may not be defined on all systems:
277#ifdef EAI_SYSTEM
278 case EAI_SYSTEM:
279 return GetAddrInfoError::SYSTEM;
280#endif
281#ifdef EAI_BADHINTS
282 case EAI_BADHINTS:
283 return GetAddrInfoError::BADHINTS;
284#endif
285#ifdef EAI_PROTOCOL
286 case EAI_PROTOCOL:
287 return GetAddrInfoError::PROTOCOL;
288#endif
289#ifdef EAI_OVERFLOW
290 case EAI_OVERFLOW:
291 return GetAddrInfoError::OVERFLOW_;
292#endif
293 default:
294#ifdef EAI_NODATA
295 // This can't be a case statement because it would create a duplicate
296 // case on Windows where EAI_NODATA is an alias for EAI_NONAME.
297 if (gai_err == EAI_NODATA) {
298 return GetAddrInfoError::NODATA;
299 }
300#endif
301 return GetAddrInfoError::OTHER;
302 }
303}
304
305Domain TranslateDomainFromNative(int domain) {
245 switch (domain) { 306 switch (domain) {
307 case 0:
308 return Domain::Unspecified;
309 case AF_INET:
310 return Domain::INET;
311 default:
312 UNIMPLEMENTED_MSG("Unhandled domain={}", domain);
313 return Domain::INET;
314 }
315}
316
317int TranslateDomainToNative(Domain domain) {
318 switch (domain) {
319 case Domain::Unspecified:
320 return 0;
246 case Domain::INET: 321 case Domain::INET:
247 return AF_INET; 322 return AF_INET;
248 default: 323 default:
@@ -251,20 +326,58 @@ int TranslateDomain(Domain domain) {
251 } 326 }
252} 327}
253 328
254int TranslateType(Type type) { 329Type TranslateTypeFromNative(int type) {
330 switch (type) {
331 case 0:
332 return Type::Unspecified;
333 case SOCK_STREAM:
334 return Type::STREAM;
335 case SOCK_DGRAM:
336 return Type::DGRAM;
337 case SOCK_RAW:
338 return Type::RAW;
339 case SOCK_SEQPACKET:
340 return Type::SEQPACKET;
341 default:
342 UNIMPLEMENTED_MSG("Unimplemented type={}", type);
343 return Type::STREAM;
344 }
345}
346
347int TranslateTypeToNative(Type type) {
255 switch (type) { 348 switch (type) {
349 case Type::Unspecified:
350 return 0;
256 case Type::STREAM: 351 case Type::STREAM:
257 return SOCK_STREAM; 352 return SOCK_STREAM;
258 case Type::DGRAM: 353 case Type::DGRAM:
259 return SOCK_DGRAM; 354 return SOCK_DGRAM;
355 case Type::RAW:
356 return SOCK_RAW;
260 default: 357 default:
261 UNIMPLEMENTED_MSG("Unimplemented type={}", type); 358 UNIMPLEMENTED_MSG("Unimplemented type={}", type);
262 return 0; 359 return 0;
263 } 360 }
264} 361}
265 362
266int TranslateProtocol(Protocol protocol) { 363Protocol TranslateProtocolFromNative(int protocol) {
364 switch (protocol) {
365 case 0:
366 return Protocol::Unspecified;
367 case IPPROTO_TCP:
368 return Protocol::TCP;
369 case IPPROTO_UDP:
370 return Protocol::UDP;
371 default:
372 UNIMPLEMENTED_MSG("Unimplemented protocol={}", protocol);
373 return Protocol::Unspecified;
374 }
375}
376
377int TranslateProtocolToNative(Protocol protocol) {
267 switch (protocol) { 378 switch (protocol) {
379 case Protocol::Unspecified:
380 return 0;
268 case Protocol::TCP: 381 case Protocol::TCP:
269 return IPPROTO_TCP; 382 return IPPROTO_TCP;
270 case Protocol::UDP: 383 case Protocol::UDP:
@@ -275,21 +388,10 @@ int TranslateProtocol(Protocol protocol) {
275 } 388 }
276} 389}
277 390
278SockAddrIn TranslateToSockAddrIn(sockaddr input_) { 391SockAddrIn TranslateToSockAddrIn(sockaddr_in input, size_t input_len) {
279 sockaddr_in input;
280 std::memcpy(&input, &input_, sizeof(input));
281
282 SockAddrIn result; 392 SockAddrIn result;
283 393
284 switch (input.sin_family) { 394 result.family = TranslateDomainFromNative(input.sin_family);
285 case AF_INET:
286 result.family = Domain::INET;
287 break;
288 default:
289 UNIMPLEMENTED_MSG("Unhandled sockaddr family={}", input.sin_family);
290 result.family = Domain::INET;
291 break;
292 }
293 395
294 result.portno = ntohs(input.sin_port); 396 result.portno = ntohs(input.sin_port);
295 397
@@ -301,22 +403,28 @@ SockAddrIn TranslateToSockAddrIn(sockaddr input_) {
301short TranslatePollEvents(PollEvents events) { 403short TranslatePollEvents(PollEvents events) {
302 short result = 0; 404 short result = 0;
303 405
304 if (True(events & PollEvents::In)) { 406 const auto translate = [&result, &events](PollEvents guest, short host) {
305 events &= ~PollEvents::In; 407 if (True(events & guest)) {
306 result |= POLLIN; 408 events &= ~guest;
307 } 409 result |= host;
308 if (True(events & PollEvents::Pri)) { 410 }
309 events &= ~PollEvents::Pri; 411 };
412
413 translate(PollEvents::In, POLLIN);
414 translate(PollEvents::Pri, POLLPRI);
415 translate(PollEvents::Out, POLLOUT);
416 translate(PollEvents::Err, POLLERR);
417 translate(PollEvents::Hup, POLLHUP);
418 translate(PollEvents::Nval, POLLNVAL);
419 translate(PollEvents::RdNorm, POLLRDNORM);
420 translate(PollEvents::RdBand, POLLRDBAND);
421 translate(PollEvents::WrBand, POLLWRBAND);
422
310#ifdef _WIN32 423#ifdef _WIN32
424 if (True(events & PollEvents::Pri)) {
311 LOG_WARNING(Service, "Winsock doesn't support POLLPRI"); 425 LOG_WARNING(Service, "Winsock doesn't support POLLPRI");
312#else
313 result |= POLLPRI;
314#endif
315 }
316 if (True(events & PollEvents::Out)) {
317 events &= ~PollEvents::Out;
318 result |= POLLOUT;
319 } 426 }
427#endif
320 428
321 UNIMPLEMENTED_IF_MSG((u16)events != 0, "Unhandled guest events=0x{:x}", (u16)events); 429 UNIMPLEMENTED_IF_MSG((u16)events != 0, "Unhandled guest events=0x{:x}", (u16)events);
322 430
@@ -337,6 +445,10 @@ PollEvents TranslatePollRevents(short revents) {
337 translate(POLLOUT, PollEvents::Out); 445 translate(POLLOUT, PollEvents::Out);
338 translate(POLLERR, PollEvents::Err); 446 translate(POLLERR, PollEvents::Err);
339 translate(POLLHUP, PollEvents::Hup); 447 translate(POLLHUP, PollEvents::Hup);
448 translate(POLLNVAL, PollEvents::Nval);
449 translate(POLLRDNORM, PollEvents::RdNorm);
450 translate(POLLRDBAND, PollEvents::RdBand);
451 translate(POLLWRBAND, PollEvents::WrBand);
340 452
341 UNIMPLEMENTED_IF_MSG(revents != 0, "Unhandled host revents=0x{:x}", revents); 453 UNIMPLEMENTED_IF_MSG(revents != 0, "Unhandled host revents=0x{:x}", revents);
342 454
@@ -360,12 +472,53 @@ std::optional<IPv4Address> GetHostIPv4Address() {
360 return {}; 472 return {};
361 } 473 }
362 474
363 std::array<char, 16> ip_addr = {};
364 ASSERT(inet_ntop(AF_INET, &network_interface->ip_address, ip_addr.data(), sizeof(ip_addr)) !=
365 nullptr);
366 return TranslateIPv4(network_interface->ip_address); 475 return TranslateIPv4(network_interface->ip_address);
367} 476}
368 477
478std::string IPv4AddressToString(IPv4Address ip_addr) {
479 std::array<char, INET_ADDRSTRLEN> buf = {};
480 ASSERT(inet_ntop(AF_INET, &ip_addr, buf.data(), sizeof(buf)) == buf.data());
481 return std::string(buf.data());
482}
483
484u32 IPv4AddressToInteger(IPv4Address ip_addr) {
485 return static_cast<u32>(ip_addr[0]) << 24 | static_cast<u32>(ip_addr[1]) << 16 |
486 static_cast<u32>(ip_addr[2]) << 8 | static_cast<u32>(ip_addr[3]);
487}
488
489#undef GetAddrInfo // Windows defines it as a macro
490
491Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddrInfo(
492 const std::string& host, const std::optional<std::string>& service) {
493 addrinfo hints{};
494 hints.ai_family = AF_INET; // Switch only supports IPv4.
495 addrinfo* addrinfo;
496 s32 gai_err = getaddrinfo(host.c_str(), service.has_value() ? service->c_str() : nullptr,
497 &hints, &addrinfo);
498 if (gai_err != 0) {
499 return Common::Unexpected(TranslateGetAddrInfoErrorFromNative(gai_err));
500 }
501 std::vector<AddrInfo> ret;
502 for (auto* current = addrinfo; current; current = current->ai_next) {
503 // We should only get AF_INET results due to the hints value.
504 ASSERT_OR_EXECUTE(addrinfo->ai_family == AF_INET &&
505 addrinfo->ai_addrlen == sizeof(sockaddr_in),
506 continue;);
507
508 AddrInfo& out = ret.emplace_back();
509 out.family = TranslateDomainFromNative(current->ai_family);
510 out.socket_type = TranslateTypeFromNative(current->ai_socktype);
511 out.protocol = TranslateProtocolFromNative(current->ai_protocol);
512 out.addr = TranslateToSockAddrIn(*reinterpret_cast<sockaddr_in*>(current->ai_addr),
513 current->ai_addrlen);
514 if (current->ai_canonname != nullptr) {
515 out.canon_name = current->ai_canonname;
516 }
517 }
518 freeaddrinfo(addrinfo);
519 return ret;
520}
521
369std::pair<s32, Errno> Poll(std::vector<PollFD>& pollfds, s32 timeout) { 522std::pair<s32, Errno> Poll(std::vector<PollFD>& pollfds, s32 timeout) {
370 const size_t num = pollfds.size(); 523 const size_t num = pollfds.size();
371 524
@@ -411,6 +564,18 @@ Socket::Socket(Socket&& rhs) noexcept {
411} 564}
412 565
413template <typename T> 566template <typename T>
567std::pair<T, Errno> Socket::GetSockOpt(SOCKET fd_, int option) {
568 T value{};
569 socklen_t len = sizeof(value);
570 const int result = getsockopt(fd_, SOL_SOCKET, option, reinterpret_cast<char*>(&value), &len);
571 if (result != SOCKET_ERROR) {
572 ASSERT(len == sizeof(value));
573 return {value, Errno::SUCCESS};
574 }
575 return {value, GetAndLogLastError()};
576}
577
578template <typename T>
414Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) { 579Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) {
415 const int result = 580 const int result =
416 setsockopt(fd_, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value)); 581 setsockopt(fd_, SOL_SOCKET, option, reinterpret_cast<const char*>(&value), sizeof(value));
@@ -421,7 +586,8 @@ Errno Socket::SetSockOpt(SOCKET fd_, int option, T value) {
421} 586}
422 587
423Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) { 588Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) {
424 fd = socket(TranslateDomain(domain), TranslateType(type), TranslateProtocol(protocol)); 589 fd = socket(TranslateDomainToNative(domain), TranslateTypeToNative(type),
590 TranslateProtocolToNative(protocol));
425 if (fd != INVALID_SOCKET) { 591 if (fd != INVALID_SOCKET) {
426 return Errno::SUCCESS; 592 return Errno::SUCCESS;
427 } 593 }
@@ -430,19 +596,17 @@ Errno Socket::Initialize(Domain domain, Type type, Protocol protocol) {
430} 596}
431 597
432std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() { 598std::pair<SocketBase::AcceptResult, Errno> Socket::Accept() {
433 sockaddr addr; 599 sockaddr_in addr;
434 socklen_t addrlen = sizeof(addr); 600 socklen_t addrlen = sizeof(addr);
435 const SOCKET new_socket = accept(fd, &addr, &addrlen); 601 const SOCKET new_socket = accept(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen);
436 602
437 if (new_socket == INVALID_SOCKET) { 603 if (new_socket == INVALID_SOCKET) {
438 return {AcceptResult{}, GetAndLogLastError()}; 604 return {AcceptResult{}, GetAndLogLastError()};
439 } 605 }
440 606
441 ASSERT(addrlen == sizeof(sockaddr_in));
442
443 AcceptResult result{ 607 AcceptResult result{
444 .socket = std::make_unique<Socket>(new_socket), 608 .socket = std::make_unique<Socket>(new_socket),
445 .sockaddr_in = TranslateToSockAddrIn(addr), 609 .sockaddr_in = TranslateToSockAddrIn(addr, addrlen),
446 }; 610 };
447 611
448 return {std::move(result), Errno::SUCCESS}; 612 return {std::move(result), Errno::SUCCESS};
@@ -458,25 +622,23 @@ Errno Socket::Connect(SockAddrIn addr_in) {
458} 622}
459 623
460std::pair<SockAddrIn, Errno> Socket::GetPeerName() { 624std::pair<SockAddrIn, Errno> Socket::GetPeerName() {
461 sockaddr addr; 625 sockaddr_in addr;
462 socklen_t addrlen = sizeof(addr); 626 socklen_t addrlen = sizeof(addr);
463 if (getpeername(fd, &addr, &addrlen) == SOCKET_ERROR) { 627 if (getpeername(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) {
464 return {SockAddrIn{}, GetAndLogLastError()}; 628 return {SockAddrIn{}, GetAndLogLastError()};
465 } 629 }
466 630
467 ASSERT(addrlen == sizeof(sockaddr_in)); 631 return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS};
468 return {TranslateToSockAddrIn(addr), Errno::SUCCESS};
469} 632}
470 633
471std::pair<SockAddrIn, Errno> Socket::GetSockName() { 634std::pair<SockAddrIn, Errno> Socket::GetSockName() {
472 sockaddr addr; 635 sockaddr_in addr;
473 socklen_t addrlen = sizeof(addr); 636 socklen_t addrlen = sizeof(addr);
474 if (getsockname(fd, &addr, &addrlen) == SOCKET_ERROR) { 637 if (getsockname(fd, reinterpret_cast<sockaddr*>(&addr), &addrlen) == SOCKET_ERROR) {
475 return {SockAddrIn{}, GetAndLogLastError()}; 638 return {SockAddrIn{}, GetAndLogLastError()};
476 } 639 }
477 640
478 ASSERT(addrlen == sizeof(sockaddr_in)); 641 return {TranslateToSockAddrIn(addr, addrlen), Errno::SUCCESS};
479 return {TranslateToSockAddrIn(addr), Errno::SUCCESS};
480} 642}
481 643
482Errno Socket::Bind(SockAddrIn addr) { 644Errno Socket::Bind(SockAddrIn addr) {
@@ -519,7 +681,7 @@ Errno Socket::Shutdown(ShutdownHow how) {
519 return GetAndLogLastError(); 681 return GetAndLogLastError();
520} 682}
521 683
522std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) { 684std::pair<s32, Errno> Socket::Recv(int flags, std::span<u8> message) {
523 ASSERT(flags == 0); 685 ASSERT(flags == 0);
524 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); 686 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
525 687
@@ -532,21 +694,20 @@ std::pair<s32, Errno> Socket::Recv(int flags, std::vector<u8>& message) {
532 return {-1, GetAndLogLastError()}; 694 return {-1, GetAndLogLastError()};
533} 695}
534 696
535std::pair<s32, Errno> Socket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) { 697std::pair<s32, Errno> Socket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) {
536 ASSERT(flags == 0); 698 ASSERT(flags == 0);
537 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); 699 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
538 700
539 sockaddr addr_in{}; 701 sockaddr_in addr_in{};
540 socklen_t addrlen = sizeof(addr_in); 702 socklen_t addrlen = sizeof(addr_in);
541 socklen_t* const p_addrlen = addr ? &addrlen : nullptr; 703 socklen_t* const p_addrlen = addr ? &addrlen : nullptr;
542 sockaddr* const p_addr_in = addr ? &addr_in : nullptr; 704 sockaddr* const p_addr_in = addr ? reinterpret_cast<sockaddr*>(&addr_in) : nullptr;
543 705
544 const auto result = recvfrom(fd, reinterpret_cast<char*>(message.data()), 706 const auto result = recvfrom(fd, reinterpret_cast<char*>(message.data()),
545 static_cast<int>(message.size()), 0, p_addr_in, p_addrlen); 707 static_cast<int>(message.size()), 0, p_addr_in, p_addrlen);
546 if (result != SOCKET_ERROR) { 708 if (result != SOCKET_ERROR) {
547 if (addr) { 709 if (addr) {
548 ASSERT(addrlen == sizeof(addr_in)); 710 *addr = TranslateToSockAddrIn(addr_in, addrlen);
549 *addr = TranslateToSockAddrIn(addr_in);
550 } 711 }
551 return {static_cast<s32>(result), Errno::SUCCESS}; 712 return {static_cast<s32>(result), Errno::SUCCESS};
552 } 713 }
@@ -597,6 +758,11 @@ Errno Socket::Close() {
597 return Errno::SUCCESS; 758 return Errno::SUCCESS;
598} 759}
599 760
761std::pair<Errno, Errno> Socket::GetPendingError() {
762 auto [pending_err, getsockopt_err] = GetSockOpt<int>(fd, SO_ERROR);
763 return {TranslateNativeError(pending_err), getsockopt_err};
764}
765
600Errno Socket::SetLinger(bool enable, u32 linger) { 766Errno Socket::SetLinger(bool enable, u32 linger) {
601 return SetSockOpt(fd, SO_LINGER, MakeLinger(enable, linger)); 767 return SetSockOpt(fd, SO_LINGER, MakeLinger(enable, linger));
602} 768}
diff --git a/src/core/internal_network/network.h b/src/core/internal_network/network.h
index 1e09a007a..96319bfc8 100644
--- a/src/core/internal_network/network.h
+++ b/src/core/internal_network/network.h
@@ -16,6 +16,11 @@
16#include <netinet/in.h> 16#include <netinet/in.h>
17#endif 17#endif
18 18
19namespace Common {
20template <typename T, typename E>
21class Expected;
22}
23
19namespace Network { 24namespace Network {
20 25
21class SocketBase; 26class SocketBase;
@@ -36,6 +41,26 @@ enum class Errno {
36 NETUNREACH, 41 NETUNREACH,
37 TIMEDOUT, 42 TIMEDOUT,
38 MSGSIZE, 43 MSGSIZE,
44 INPROGRESS,
45 OTHER,
46};
47
48enum class GetAddrInfoError {
49 SUCCESS,
50 ADDRFAMILY,
51 AGAIN,
52 BADFLAGS,
53 FAIL,
54 FAMILY,
55 MEMORY,
56 NODATA,
57 NONAME,
58 SERVICE,
59 SOCKTYPE,
60 SYSTEM,
61 BADHINTS,
62 PROTOCOL,
63 OVERFLOW_,
39 OTHER, 64 OTHER,
40}; 65};
41 66
@@ -49,6 +74,9 @@ enum class PollEvents : u16 {
49 Err = 1 << 3, 74 Err = 1 << 3,
50 Hup = 1 << 4, 75 Hup = 1 << 4,
51 Nval = 1 << 5, 76 Nval = 1 << 5,
77 RdNorm = 1 << 6,
78 RdBand = 1 << 7,
79 WrBand = 1 << 8,
52}; 80};
53 81
54DECLARE_ENUM_FLAG_OPERATORS(PollEvents); 82DECLARE_ENUM_FLAG_OPERATORS(PollEvents);
@@ -82,4 +110,10 @@ constexpr IPv4Address TranslateIPv4(in_addr addr) {
82/// @return human ordered IPv4 address (e.g. 192.168.0.1) as an array 110/// @return human ordered IPv4 address (e.g. 192.168.0.1) as an array
83std::optional<IPv4Address> GetHostIPv4Address(); 111std::optional<IPv4Address> GetHostIPv4Address();
84 112
113std::string IPv4AddressToString(IPv4Address ip_addr);
114u32 IPv4AddressToInteger(IPv4Address ip_addr);
115
116Common::Expected<std::vector<AddrInfo>, GetAddrInfoError> GetAddrInfo(
117 const std::string& host, const std::optional<std::string>& service);
118
85} // namespace Network 119} // namespace Network
diff --git a/src/core/internal_network/socket_proxy.cpp b/src/core/internal_network/socket_proxy.cpp
index 7a77171c2..44e9e3093 100644
--- a/src/core/internal_network/socket_proxy.cpp
+++ b/src/core/internal_network/socket_proxy.cpp
@@ -98,7 +98,7 @@ Errno ProxySocket::Shutdown(ShutdownHow how) {
98 return Errno::SUCCESS; 98 return Errno::SUCCESS;
99} 99}
100 100
101std::pair<s32, Errno> ProxySocket::Recv(int flags, std::vector<u8>& message) { 101std::pair<s32, Errno> ProxySocket::Recv(int flags, std::span<u8> message) {
102 LOG_WARNING(Network, "(STUBBED) called"); 102 LOG_WARNING(Network, "(STUBBED) called");
103 ASSERT(flags == 0); 103 ASSERT(flags == 0);
104 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); 104 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
@@ -106,7 +106,7 @@ std::pair<s32, Errno> ProxySocket::Recv(int flags, std::vector<u8>& message) {
106 return {static_cast<s32>(0), Errno::SUCCESS}; 106 return {static_cast<s32>(0), Errno::SUCCESS};
107} 107}
108 108
109std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) { 109std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) {
110 ASSERT(flags == 0); 110 ASSERT(flags == 0);
111 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max())); 111 ASSERT(message.size() < static_cast<size_t>(std::numeric_limits<int>::max()));
112 112
@@ -140,8 +140,8 @@ std::pair<s32, Errno> ProxySocket::RecvFrom(int flags, std::vector<u8>& message,
140 } 140 }
141} 141}
142 142
143std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& message, 143std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::span<u8> message, SockAddrIn* addr,
144 SockAddrIn* addr, std::size_t max_length) { 144 std::size_t max_length) {
145 ProxyPacket& packet = received_packets.front(); 145 ProxyPacket& packet = received_packets.front();
146 if (addr) { 146 if (addr) {
147 addr->family = Domain::INET; 147 addr->family = Domain::INET;
@@ -153,10 +153,7 @@ std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& mes
153 std::size_t read_bytes; 153 std::size_t read_bytes;
154 if (packet.data.size() > max_length) { 154 if (packet.data.size() > max_length) {
155 read_bytes = max_length; 155 read_bytes = max_length;
156 message.clear(); 156 memcpy(message.data(), packet.data.data(), max_length);
157 std::copy(packet.data.begin(), packet.data.begin() + read_bytes,
158 std::back_inserter(message));
159 message.resize(max_length);
160 157
161 if (protocol == Protocol::UDP) { 158 if (protocol == Protocol::UDP) {
162 if (!peek) { 159 if (!peek) {
@@ -171,9 +168,7 @@ std::pair<s32, Errno> ProxySocket::ReceivePacket(int flags, std::vector<u8>& mes
171 } 168 }
172 } else { 169 } else {
173 read_bytes = packet.data.size(); 170 read_bytes = packet.data.size();
174 message.clear(); 171 memcpy(message.data(), packet.data.data(), read_bytes);
175 std::copy(packet.data.begin(), packet.data.end(), std::back_inserter(message));
176 message.resize(max_length);
177 if (!peek) { 172 if (!peek) {
178 received_packets.pop(); 173 received_packets.pop();
179 } 174 }
@@ -293,6 +288,11 @@ Errno ProxySocket::SetNonBlock(bool enable) {
293 return Errno::SUCCESS; 288 return Errno::SUCCESS;
294} 289}
295 290
291std::pair<Errno, Errno> ProxySocket::GetPendingError() {
292 LOG_DEBUG(Network, "(STUBBED) called");
293 return {Errno::SUCCESS, Errno::SUCCESS};
294}
295
296bool ProxySocket::IsOpened() const { 296bool ProxySocket::IsOpened() const {
297 return fd != INVALID_SOCKET; 297 return fd != INVALID_SOCKET;
298} 298}
diff --git a/src/core/internal_network/socket_proxy.h b/src/core/internal_network/socket_proxy.h
index 6e991fa38..e12c413d1 100644
--- a/src/core/internal_network/socket_proxy.h
+++ b/src/core/internal_network/socket_proxy.h
@@ -39,11 +39,11 @@ public:
39 39
40 Errno Shutdown(ShutdownHow how) override; 40 Errno Shutdown(ShutdownHow how) override;
41 41
42 std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) override; 42 std::pair<s32, Errno> Recv(int flags, std::span<u8> message) override;
43 43
44 std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) override; 44 std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) override;
45 45
46 std::pair<s32, Errno> ReceivePacket(int flags, std::vector<u8>& message, SockAddrIn* addr, 46 std::pair<s32, Errno> ReceivePacket(int flags, std::span<u8> message, SockAddrIn* addr,
47 std::size_t max_length); 47 std::size_t max_length);
48 48
49 std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override; 49 std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override;
@@ -74,6 +74,8 @@ public:
74 template <typename T> 74 template <typename T>
75 Errno SetSockOpt(SOCKET fd, int option, T value); 75 Errno SetSockOpt(SOCKET fd, int option, T value);
76 76
77 std::pair<Errno, Errno> GetPendingError() override;
78
77 bool IsOpened() const override; 79 bool IsOpened() const override;
78 80
79private: 81private:
diff --git a/src/core/internal_network/sockets.h b/src/core/internal_network/sockets.h
index 11e479e50..46a53ef79 100644
--- a/src/core/internal_network/sockets.h
+++ b/src/core/internal_network/sockets.h
@@ -59,10 +59,9 @@ public:
59 59
60 virtual Errno Shutdown(ShutdownHow how) = 0; 60 virtual Errno Shutdown(ShutdownHow how) = 0;
61 61
62 virtual std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) = 0; 62 virtual std::pair<s32, Errno> Recv(int flags, std::span<u8> message) = 0;
63 63
64 virtual std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, 64 virtual std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) = 0;
65 SockAddrIn* addr) = 0;
66 65
67 virtual std::pair<s32, Errno> Send(std::span<const u8> message, int flags) = 0; 66 virtual std::pair<s32, Errno> Send(std::span<const u8> message, int flags) = 0;
68 67
@@ -87,6 +86,8 @@ public:
87 86
88 virtual Errno SetNonBlock(bool enable) = 0; 87 virtual Errno SetNonBlock(bool enable) = 0;
89 88
89 virtual std::pair<Errno, Errno> GetPendingError() = 0;
90
90 virtual bool IsOpened() const = 0; 91 virtual bool IsOpened() const = 0;
91 92
92 virtual void HandleProxyPacket(const ProxyPacket& packet) = 0; 93 virtual void HandleProxyPacket(const ProxyPacket& packet) = 0;
@@ -126,9 +127,9 @@ public:
126 127
127 Errno Shutdown(ShutdownHow how) override; 128 Errno Shutdown(ShutdownHow how) override;
128 129
129 std::pair<s32, Errno> Recv(int flags, std::vector<u8>& message) override; 130 std::pair<s32, Errno> Recv(int flags, std::span<u8> message) override;
130 131
131 std::pair<s32, Errno> RecvFrom(int flags, std::vector<u8>& message, SockAddrIn* addr) override; 132 std::pair<s32, Errno> RecvFrom(int flags, std::span<u8> message, SockAddrIn* addr) override;
132 133
133 std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override; 134 std::pair<s32, Errno> Send(std::span<const u8> message, int flags) override;
134 135
@@ -156,6 +157,11 @@ public:
156 template <typename T> 157 template <typename T>
157 Errno SetSockOpt(SOCKET fd, int option, T value); 158 Errno SetSockOpt(SOCKET fd, int option, T value);
158 159
160 std::pair<Errno, Errno> GetPendingError() override;
161
162 template <typename T>
163 std::pair<T, Errno> GetSockOpt(SOCKET fd, int option);
164
159 bool IsOpened() const override; 165 bool IsOpened() const override;
160 166
161 void HandleProxyPacket(const ProxyPacket& packet) override; 167 void HandleProxyPacket(const ProxyPacket& packet) override;