summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/core/hle/service/nvdrv/devices/nvmap.cpp72
1 files changed, 60 insertions, 12 deletions
diff --git a/src/core/hle/service/nvdrv/devices/nvmap.cpp b/src/core/hle/service/nvdrv/devices/nvmap.cpp
index a2287cc1b..43651d8a6 100644
--- a/src/core/hle/service/nvdrv/devices/nvmap.cpp
+++ b/src/core/hle/service/nvdrv/devices/nvmap.cpp
@@ -11,6 +11,13 @@
11 11
12namespace Service::Nvidia::Devices { 12namespace Service::Nvidia::Devices {
13 13
14namespace NvErrCodes {
15enum {
16 OperationNotPermitted = -1,
17 InvalidValue = -22,
18};
19}
20
14nvmap::nvmap() = default; 21nvmap::nvmap() = default;
15nvmap::~nvmap() = default; 22nvmap::~nvmap() = default;
16 23
@@ -44,7 +51,11 @@ u32 nvmap::ioctl(Ioctl command, const std::vector<u8>& input, std::vector<u8>& o
44u32 nvmap::IocCreate(const std::vector<u8>& input, std::vector<u8>& output) { 51u32 nvmap::IocCreate(const std::vector<u8>& input, std::vector<u8>& output) {
45 IocCreateParams params; 52 IocCreateParams params;
46 std::memcpy(&params, input.data(), sizeof(params)); 53 std::memcpy(&params, input.data(), sizeof(params));
54 LOG_DEBUG(Service_NVDRV, "size=0x{:08X}", params.size);
47 55
56 if (!params.size) {
57 return static_cast<u32>(NvErrCodes::InvalidValue);
58 }
48 // Create a new nvmap object and obtain a handle to it. 59 // Create a new nvmap object and obtain a handle to it.
49 auto object = std::make_shared<Object>(); 60 auto object = std::make_shared<Object>();
50 object->id = next_id++; 61 object->id = next_id++;
@@ -55,8 +66,6 @@ u32 nvmap::IocCreate(const std::vector<u8>& input, std::vector<u8>& output) {
55 u32 handle = next_handle++; 66 u32 handle = next_handle++;
56 handles[handle] = std::move(object); 67 handles[handle] = std::move(object);
57 68
58 LOG_DEBUG(Service_NVDRV, "size=0x{:08X}", params.size);
59
60 params.handle = handle; 69 params.handle = handle;
61 70
62 std::memcpy(output.data(), &params, sizeof(params)); 71 std::memcpy(output.data(), &params, sizeof(params));
@@ -66,9 +75,29 @@ u32 nvmap::IocCreate(const std::vector<u8>& input, std::vector<u8>& output) {
66u32 nvmap::IocAlloc(const std::vector<u8>& input, std::vector<u8>& output) { 75u32 nvmap::IocAlloc(const std::vector<u8>& input, std::vector<u8>& output) {
67 IocAllocParams params; 76 IocAllocParams params;
68 std::memcpy(&params, input.data(), sizeof(params)); 77 std::memcpy(&params, input.data(), sizeof(params));
78 LOG_DEBUG(Service_NVDRV, "called, addr={:X}", params.addr);
79
80 if (!params.handle) {
81 return static_cast<u32>(NvErrCodes::InvalidValue);
82 }
83
84 if ((params.align - 1) & params.align) {
85 return static_cast<u32>(NvErrCodes::InvalidValue);
86 }
87
88 const u32 min_alignment = 0x1000;
89 if (params.align < min_alignment) {
90 params.align = min_alignment;
91 }
69 92
70 auto object = GetObject(params.handle); 93 auto object = GetObject(params.handle);
71 ASSERT(object); 94 if (!object) {
95 return static_cast<u32>(NvErrCodes::InvalidValue);
96 }
97
98 if (object->status == Object::Status::Allocated) {
99 return static_cast<u32>(NvErrCodes::OperationNotPermitted);
100 }
72 101
73 object->flags = params.flags; 102 object->flags = params.flags;
74 object->align = params.align; 103 object->align = params.align;
@@ -76,8 +105,6 @@ u32 nvmap::IocAlloc(const std::vector<u8>& input, std::vector<u8>& output) {
76 object->addr = params.addr; 105 object->addr = params.addr;
77 object->status = Object::Status::Allocated; 106 object->status = Object::Status::Allocated;
78 107
79 LOG_DEBUG(Service_NVDRV, "called, addr={:X}", params.addr);
80
81 std::memcpy(output.data(), &params, sizeof(params)); 108 std::memcpy(output.data(), &params, sizeof(params));
82 return 0; 109 return 0;
83} 110}
@@ -88,8 +115,14 @@ u32 nvmap::IocGetId(const std::vector<u8>& input, std::vector<u8>& output) {
88 115
89 LOG_WARNING(Service_NVDRV, "called"); 116 LOG_WARNING(Service_NVDRV, "called");
90 117
118 if (!params.handle) {
119 return static_cast<u32>(NvErrCodes::InvalidValue);
120 }
121
91 auto object = GetObject(params.handle); 122 auto object = GetObject(params.handle);
92 ASSERT(object); 123 if (!object) {
124 return static_cast<u32>(NvErrCodes::OperationNotPermitted);
125 }
93 126
94 params.id = object->id; 127 params.id = object->id;
95 128
@@ -105,7 +138,14 @@ u32 nvmap::IocFromId(const std::vector<u8>& input, std::vector<u8>& output) {
105 138
106 auto itr = std::find_if(handles.begin(), handles.end(), 139 auto itr = std::find_if(handles.begin(), handles.end(),
107 [&](const auto& entry) { return entry.second->id == params.id; }); 140 [&](const auto& entry) { return entry.second->id == params.id; });
108 ASSERT(itr != handles.end()); 141 if (itr == handles.end()) {
142 return static_cast<u32>(NvErrCodes::InvalidValue);
143 }
144
145 auto& object = itr->second;
146 if (object->status != Object::Status::Allocated) {
147 return static_cast<u32>(NvErrCodes::InvalidValue);
148 }
109 149
110 itr->second->refcount++; 150 itr->second->refcount++;
111 151
@@ -125,8 +165,13 @@ u32 nvmap::IocParam(const std::vector<u8>& input, std::vector<u8>& output) {
125 LOG_WARNING(Service_NVDRV, "(STUBBED) called type={}", params.param); 165 LOG_WARNING(Service_NVDRV, "(STUBBED) called type={}", params.param);
126 166
127 auto object = GetObject(params.handle); 167 auto object = GetObject(params.handle);
128 ASSERT(object); 168 if (!object) {
129 ASSERT(object->status == Object::Status::Allocated); 169 return static_cast<u32>(NvErrCodes::InvalidValue);
170 }
171
172 if (object->status != Object::Status::Allocated) {
173 return static_cast<u32>(NvErrCodes::OperationNotPermitted);
174 }
130 175
131 switch (static_cast<ParamTypes>(params.param)) { 176 switch (static_cast<ParamTypes>(params.param)) {
132 case ParamTypes::Size: 177 case ParamTypes::Size:
@@ -163,9 +208,12 @@ u32 nvmap::IocFree(const std::vector<u8>& input, std::vector<u8>& output) {
163 LOG_WARNING(Service_NVDRV, "(STUBBED) called"); 208 LOG_WARNING(Service_NVDRV, "(STUBBED) called");
164 209
165 auto itr = handles.find(params.handle); 210 auto itr = handles.find(params.handle);
166 ASSERT(itr != handles.end()); 211 if (itr == handles.end()) {
167 212 return static_cast<u32>(NvErrCodes::InvalidValue);
168 ASSERT(itr->second->refcount > 0); 213 }
214 if (!itr->second->refcount) {
215 return static_cast<u32>(NvErrCodes::InvalidValue);
216 }
169 217
170 itr->second->refcount--; 218 itr->second->refcount--;
171 219