um i think this owrks

This commit is contained in:
Mars 2025-04-23 21:47:25 -04:00
parent cf51e3e569
commit 2219182539
9 changed files with 306 additions and 334 deletions

View file

@ -9,20 +9,57 @@
// clang-format on
#include <cstring>
#include <guiddef.h>
#include <ranges>
#include <winrt/Windows.Foundation.h>
#include <winrt/Windows.Media.Control.h>
#include <winrt/Windows.Storage.h>
#include <winrt/Windows.System.Diagnostics.h>
#include <winrt/Windows.System.Profile.h>
#include <winrt/base.h>
#include <winrt/impl/Windows.Media.Control.2.h>
#include "os.h"
using RtlGetVersionPtr = NTSTATUS(WINAPI*)(PRTL_OSVERSIONINFOW);
// NOLINTBEGIN(*-pro-type-cstyle-cast,*-no-int-to-ptr,*-pro-type-reinterpret-cast)
namespace {
struct OSVersion {
u16 major;
u16 minor;
u16 build;
u16 revision;
static fn parseDeviceFamilyVersion(const winrt::hstring& versionString) -> OSVersion {
try {
const u64 versionUl = std::stoull(winrt::to_string(versionString));
return {
.major = static_cast<u16>((versionUl >> 48) & 0xFFFF),
.minor = static_cast<u16>((versionUl >> 32) & 0xFFFF),
.build = static_cast<u16>((versionUl >> 16) & 0xFFFF),
.revision = static_cast<u16>(versionUl & 0xFFFF),
};
} catch (const std::invalid_argument& e) {
ERROR_LOG("Invalid argument: {}", e.what());
} catch (const std::out_of_range& e) {
ERROR_LOG("Value out of range: {}", e.what());
} catch (const winrt::hresult_error& e) { ERROR_LOG("Windows error: {}", winrt::to_string(e.message())); }
return { .major = 0, .minor = 0, .build = 0, .revision = 0 };
}
};
// clang-format off
constexpr Array<Pair<StringView, StringView>, 3> windowsShellMap = {{
{ "cmd", "Command Prompt" },
{ "powershell", "PowerShell" },
{ "pwsh", "PowerShell Core" },
}};
constexpr Array<Pair<StringView, StringView>, 3> msysShellMap = {{
{ "bash", "Bash" },
{ "zsh", "Zsh" },
{ "fish", "Fish" },
}};
// clang-format on
class ProcessSnapshot {
public:
ProcessSnapshot() : h_snapshot(CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0)) {}
@ -39,8 +76,8 @@ namespace {
[[nodiscard]] fn isValid() const -> bool { return h_snapshot != INVALID_HANDLE_VALUE; }
[[nodiscard]] fn getProcesses() const -> std::vector<std::pair<DWORD, String>> {
std::vector<std::pair<DWORD, String>> processes;
[[nodiscard]] fn getProcesses() const -> Vec<Pair<DWORD, String>> {
Vec<Pair<DWORD, String>> processes;
if (!isValid())
return processes;
@ -51,12 +88,9 @@ namespace {
if (!Process32First(h_snapshot, &pe32))
return processes;
// Get first process
if (Process32First(h_snapshot, &pe32)) {
// Add first process to vector
processes.emplace_back(pe32.th32ProcessID, String(reinterpret_cast<const char*>(pe32.szExeFile)));
// Add remaining processes
while (Process32Next(h_snapshot, &pe32))
processes.emplace_back(pe32.th32ProcessID, String(reinterpret_cast<const char*>(pe32.szExeFile)));
}
@ -79,10 +113,9 @@ namespace {
return "";
}
// For string values, allocate one less byte to avoid the null terminator
String value((type == REG_SZ || type == REG_EXPAND_SZ) ? dataSize - 1 : dataSize, '\0');
if (RegQueryValueExA(key, valueName.c_str(), nullptr, nullptr, std::bit_cast<LPBYTE>(value.data()), &dataSize) !=
if (RegQueryValueExA(key, valueName.c_str(), nullptr, nullptr, reinterpret_cast<LPBYTE>(value.data()), &dataSize) !=
ERROR_SUCCESS) {
RegCloseKey(key);
return "";
@ -92,12 +125,12 @@ namespace {
return value;
}
fn GetProcessInfo() -> std::vector<std::pair<DWORD, String>> {
fn GetProcessInfo() -> Vec<Pair<DWORD, String>> {
const ProcessSnapshot snapshot;
return snapshot.isValid() ? snapshot.getProcesses() : std::vector<std::pair<DWORD, String>> {};
}
fn IsProcessRunning(const std::vector<String>& processes, const String& name) -> bool {
fn IsProcessRunning(const Vec<String>& processes, const String& name) -> bool {
return std::ranges::any_of(processes, [&name](const String& proc) -> bool {
return _stricmp(proc.c_str(), name.c_str()) == 0;
});
@ -143,139 +176,130 @@ namespace {
return "";
}
template <usize sz>
fn FindShellInProcessTree(const DWORD startPid, const Array<Pair<StringView, StringView>, sz>& shellMap)
-> std::optional<String> {
DWORD pid = startPid;
while (pid != 0) {
String processName = GetProcessName(pid);
if (processName.empty()) {
pid = GetParentProcessId(pid);
continue;
}
std::ranges::transform(processName, processName.begin(), [](const u8 character) {
return static_cast<char>(std::tolower(static_cast<unsigned char>(character)));
});
if (processName.length() > 4 && processName.substr(processName.length() - 4) == ".exe")
processName.resize(processName.length() - 4);
auto iter = std::ranges::find_if(shellMap, [&](const auto& pair) {
return std::string_view { processName } == pair.first;
});
if (iter != std::ranges::end(shellMap))
return String { iter->second };
pid = GetParentProcessId(pid);
}
return std::nullopt;
}
fn GetBuildNumber() -> Option<u64> {
try {
using namespace winrt::Windows::System::Profile;
const auto versionInfo = AnalyticsInfo::VersionInfo();
const winrt::hstring familyVersion = versionInfo.DeviceFamilyVersion();
if (!familyVersion.empty()) {
const u64 versionUl = std::stoull(winrt::to_string(familyVersion));
return (versionUl >> 16) & 0xFFFF;
}
} catch (const winrt::hresult_error& e) {
DEBUG_LOG("WinRT error getting build number: {}", winrt::to_string(e.message()));
} catch (const Exception& e) { DEBUG_LOG("Standard exception getting build number: {}", e.what()); }
return None;
}
}
fn GetMemInfo() -> Result<u64, String> {
fn os::GetMemInfo() -> Result<u64, String> {
try {
using namespace winrt::Windows::System::Diagnostics;
const SystemDiagnosticInfo diag = SystemDiagnosticInfo::GetForCurrentSystem();
return diag.MemoryUsage().GetReport().TotalPhysicalSizeInBytes();
return winrt::Windows::System::Diagnostics::SystemDiagnosticInfo::GetForCurrentSystem()
.MemoryUsage()
.GetReport()
.TotalPhysicalSizeInBytes();
} catch (const winrt::hresult_error& e) {
return Err(std::format("Failed to get memory info: {}", to_string(e.message())));
}
}
fn GetNowPlaying() -> Result<String, NowPlayingError> {
fn os::GetNowPlaying() -> Result<String, NowPlayingError> {
using namespace winrt::Windows::Media::Control;
using namespace winrt::Windows::Foundation;
using MediaProperties = GlobalSystemMediaTransportControlsSessionMediaProperties;
using Session = GlobalSystemMediaTransportControlsSession;
using SessionManager = GlobalSystemMediaTransportControlsSessionManager;
using Session = GlobalSystemMediaTransportControlsSession;
using SessionManager = GlobalSystemMediaTransportControlsSessionManager;
try {
// Request the session manager asynchronously
const IAsyncOperation<SessionManager> sessionManagerOp = SessionManager::RequestAsync();
const SessionManager sessionManager = sessionManagerOp.get();
if (const Session currentSession = sessionManager.GetCurrentSession()) {
// Try to get the media properties asynchronously
const MediaProperties mediaProperties = currentSession.TryGetMediaPropertiesAsync().get();
if (const Session currentSession = sessionManager.GetCurrentSession())
return winrt::to_string(currentSession.TryGetMediaPropertiesAsync().get().Title());
// Convert the hstring title to string
return to_string(mediaProperties.Title());
}
// If we reach this point, there is no current session
return Err(NowPlayingCode::NoActivePlayer);
} catch (const winrt::hresult_error& e) { return Err(e); }
}
fn GetOSVersion() -> Result<String, String> {
constexpr OSVERSIONINFOEXW osvi = { sizeof(OSVERSIONINFOEXW), 0, 0, 0, 0, { 0 }, 0, 0, 0, 0, 0 };
NTSTATUS status = 0;
fn os::GetOSVersion() -> Result<String, String> {
try {
const String regSubKey = R"(SOFTWARE\Microsoft\Windows NT\CurrentVersion)";
if (const HMODULE ntdllHandle = GetModuleHandleW(L"ntdll.dll"))
if (const auto rtlGetVersion = std::bit_cast<RtlGetVersionPtr>(GetProcAddress(ntdllHandle, "RtlGetVersion")))
status = rtlGetVersion(std::bit_cast<PRTL_OSVERSIONINFOW>(&osvi));
String productName = GetRegistryValue(HKEY_LOCAL_MACHINE, regSubKey, "ProductName");
const String displayVersion = GetRegistryValue(HKEY_LOCAL_MACHINE, regSubKey, "DisplayVersion");
String productName;
String edition;
if (productName.empty())
return Err("Failed to read ProductName");
if (status == 0) {
DWORD productType = 0;
if (GetProductInfo(
osvi.dwMajorVersion, osvi.dwMinorVersion, osvi.wServicePackMajor, osvi.wServicePackMinor, &productType
)) {
if (osvi.dwMajorVersion == 10) {
if (osvi.dwBuildNumber >= 22000)
productName = "Windows 11";
else
productName = "Windows 10";
if (const Option<u64> buildNumber = GetBuildNumber()) {
if (*buildNumber >= 22000)
if (const usize pos = productName.find("Windows 10");
pos != String::npos && (pos == 0 || !isalnum(static_cast<u8>(productName[pos - 1]))) &&
(pos + 10 == productName.length() || !isalnum(static_cast<u8>(productName[pos + 10]))))
productName.replace(pos, 10, "Windows 11");
} else
DEBUG_LOG("Warning: Could not get build number via WinRT; Win11 patch relies on registry ProductName only.");
switch (productType) {
case PRODUCT_PROFESSIONAL:
edition = " Pro";
break;
case PRODUCT_ENTERPRISE:
edition = " Enterprise";
break;
case PRODUCT_EDUCATION:
edition = " Education";
break;
case PRODUCT_HOME_BASIC:
case PRODUCT_HOME_PREMIUM:
edition = " Home";
break;
case PRODUCT_CLOUDEDITION:
edition = " Cloud";
break;
default:
break;
}
}
}
} else {
productName =
GetRegistryValue(HKEY_LOCAL_MACHINE, R"(SOFTWARE\Microsoft\Windows NT\CurrentVersion)", "ProductName");
if (const i32 buildNumber = stoi(
GetRegistryValue(HKEY_LOCAL_MACHINE, R"(SOFTWARE\Microsoft\Windows NT\CurrentVersion)", "CurrentBuildNumber")
);
buildNumber >= 22000 && productName.find("Windows 10") != String::npos)
productName.replace(productName.find("Windows 10"), 10, "Windows 11");
}
if (!productName.empty()) {
String result = productName + edition;
const String displayVersion =
GetRegistryValue(HKEY_LOCAL_MACHINE, R"(SOFTWARE\Microsoft\Windows NT\CurrentVersion)", "DisplayVersion");
if (!displayVersion.empty())
result += " " + displayVersion;
return result;
}
return "Windows";
return displayVersion.empty() ? productName : productName + " " + displayVersion;
} catch (const Exception& e) { return Err(std::format("Exception occurred getting OS version: {}", e.what())); }
}
fn GetHost() -> String {
String hostName = GetRegistryValue(HKEY_LOCAL_MACHINE, R"(SYSTEM\HardwareConfig\Current)", "SystemFamily");
return hostName;
fn os::GetHost() -> String {
return GetRegistryValue(HKEY_LOCAL_MACHINE, R"(SYSTEM\HardwareConfig\Current)", "SystemFamily");
}
fn GetKernelVersion() -> String {
// ReSharper disable once CppLocalVariableMayBeConst
if (HMODULE ntdllHandle = GetModuleHandleW(L"ntdll.dll")) {
if (const auto rtlGetVersion = std::bit_cast<RtlGetVersionPtr>(GetProcAddress(ntdllHandle, "RtlGetVersion"))) {
RTL_OSVERSIONINFOW osInfo = {};
osInfo.dwOSVersionInfoSize = sizeof(osInfo);
fn os::GetKernelVersion() -> String {
try {
using namespace winrt::Windows::System::Profile;
if (rtlGetVersion(&osInfo) == 0) {
return std::format(
"{}.{}.{}.{}", osInfo.dwMajorVersion, osInfo.dwMinorVersion, osInfo.dwBuildNumber, osInfo.dwPlatformId
);
}
}
}
const AnalyticsVersionInfo versionInfo = AnalyticsInfo::VersionInfo();
if (const winrt::hstring familyVersion = versionInfo.DeviceFamilyVersion(); !familyVersion.empty())
if (auto [major, minor, build, revision] = OSVersion::parseDeviceFamilyVersion(familyVersion); build > 0)
return std::format("{}.{}.{}.{}", major, minor, build, revision);
} catch (const winrt::hresult_error& e) {
ERROR_LOG("WinRT error: {}", winrt::to_string(e.message()));
} catch (const Exception& e) { ERROR_LOG("Failed to get kernel version: {}", e.what()); }
return "";
}
fn GetWindowManager() -> String {
fn os::GetWindowManager() -> String {
const auto processInfo = GetProcessInfo();
std::vector<String> processNames;
@ -300,7 +324,7 @@ fn GetWindowManager() -> String {
return "Windows Manager";
}
fn GetDesktopEnvironment() -> Option<String> {
fn os::GetDesktopEnvironment() -> Option<String> {
const String buildStr =
GetRegistryValue(HKEY_LOCAL_MACHINE, R"(SOFTWARE\Microsoft\Windows NT\CurrentVersion)", "CurrentBuildNumber");
@ -347,97 +371,45 @@ fn GetDesktopEnvironment() -> Option<String> {
}
}
fn GetShell() -> String {
// TODO: update this to use GetEnv
fn os::GetShell() -> String {
const DWORD currentPid = GetCurrentProcessId();
const std::unordered_map<String, String> knownShells = {
{ "cmd.exe", "Command Prompt" },
{ "powershell.exe", "PowerShell" },
{ "pwsh.exe", "PowerShell Core" },
{ "windowsterminal.exe", "Windows Terminal" },
{ "mintty.exe", "Mintty" },
{ "bash.exe", "Windows Subsystem for Linux" }
};
if (const Result<String, EnvError> msystemResult = GetEnv("MSYSTEM")) {
String shellPath;
if (const Result<String, EnvError> shellResult = GetEnv("SHELL"); !shellResult->empty())
shellPath = *shellResult;
else if (const Result<String, EnvError> loginShellResult = GetEnv("LOGINSHELL"); !loginShellResult->empty())
shellPath = *loginShellResult;
char* msystemEnv = nullptr;
if (_dupenv_s(&msystemEnv, nullptr, "MSYSTEM") == 0 && msystemEnv != nullptr) {
const std::unique_ptr<char, decltype(&free)> msystemEnvGuard(msystemEnv, free);
if (!shellPath.empty()) {
const usize lastSlash = shellPath.find_last_of("\\/");
String shellExe = (lastSlash != String::npos) ? shellPath.substr(lastSlash + 1) : shellPath;
char* shell = nullptr;
size_t shellLen = 0;
_dupenv_s(&shell, &shellLen, "SHELL");
const std::unique_ptr<char, decltype(&free)> shellGuard(shell, free);
std::ranges::transform(shellExe, shellExe.begin(), [](const u8 c) { return std::tolower(c); });
if (!shell || strlen(shell) == 0) {
char* loginShell = nullptr;
size_t loginShellLen = 0;
_dupenv_s(&loginShell, &loginShellLen, "LOGINSHELL");
const std::unique_ptr<char, decltype(&free)> loginShellGuard(loginShell, free);
shell = loginShell;
if (shellExe.ends_with(".exe"))
shellExe.resize(shellExe.length() - 4);
const auto iter =
std::ranges::find_if(msysShellMap, [&](const auto& pair) { return StringView { shellExe } == pair.first; });
if (iter != std::ranges::end(msysShellMap))
return String { iter->second };
}
if (shell) {
String shellExe;
const String shellPath = shell;
const size_t lastSlash = shellPath.find_last_of("\\/");
shellExe = (lastSlash != String::npos) ? shellPath.substr(lastSlash + 1) : shellPath;
std::ranges::transform(shellExe, shellExe.begin(), ::tolower);
if (const Option<String> msysShell = FindShellInProcessTree(currentPid, msysShellMap))
return *msysShell;
// Use a map for shell name lookup instead of multiple if statements
const std::unordered_map<StringView, String> shellNames = {
{ "bash", "Bash" },
{ "zsh", "Zsh" },
{ "fish", "Fish" }
};
for (const auto& [pattern, name] : shellNames) {
if (shellExe.find(pattern) != String::npos)
return name;
}
return shellExe.empty() ? "MSYS2" : "MSYS2/" + shellExe;
}
const auto processInfo = GetProcessInfo();
DWORD pid = GetCurrentProcessId();
while (pid != 0) {
String processName = GetProcessName(pid);
std::ranges::transform(processName, processName.begin(), ::tolower);
const std::unordered_map<String, String> msysShells = {
{ "bash.exe", "Bash" },
{ "zsh.exe", "Zsh" },
{ "fish.exe", "Fish" },
{ "mintty.exe", "Mintty" }
};
for (const auto& [msysShellExe, shellName] : msysShells) {
if (processName == msysShellExe)
return shellName;
}
pid = GetParentProcessId(pid);
}
return "MSYS2";
return "MSYS2 Environment";
}
DWORD pid = GetCurrentProcessId();
while (pid != 0) {
String processName = GetProcessName(pid);
std::ranges::transform(processName, processName.begin(), ::tolower);
if (const Option<String> windowsShell = FindShellInProcessTree(currentPid, windowsShellMap))
return *windowsShell;
if (auto shellIterator = knownShells.find(processName); shellIterator != knownShells.end())
return shellIterator->second;
pid = GetParentProcessId(pid);
}
return "Windows Console";
return "Unknown Shell";
}
fn GetDiskUsage() -> std::pair<u64, u64> {
fn os::GetDiskUsage() -> Pair<u64, u64> {
ULARGE_INTEGER freeBytes, totalBytes;
if (GetDiskFreeSpaceExW(L"C:\\", nullptr, &totalBytes, &freeBytes))
@ -445,6 +417,5 @@ fn GetDiskUsage() -> std::pair<u64, u64> {
return { 0, 0 };
}
// NOLINTEND(*-pro-type-cstyle-cast,*-no-int-to-ptr,*-pro-type-reinterpret-cast)
#endif