diff --git hadoop-common-project/hadoop-common/src/main/winutils/chown.c hadoop-common-project/hadoop-common/src/main/winutils/chown.c index bc2aefc..1be8121 100644 --- hadoop-common-project/hadoop-common/src/main/winutils/chown.c +++ hadoop-common-project/hadoop-common/src/main/winutils/chown.c @@ -63,11 +63,11 @@ static DWORD ChangeFileOwnerBySid(__in LPCWSTR path, // SID is not contained in the caller's token, and have the SE_GROUP_OWNER // permission enabled. // - if (!EnablePrivilege(L"SeTakeOwnershipPrivilege")) + if (EnablePrivilege(L"SeTakeOwnershipPrivilege") != ERROR_SUCCESS) { fwprintf(stdout, L"INFO: The user does not have SeTakeOwnershipPrivilege.\n"); } - if (!EnablePrivilege(L"SeRestorePrivilege")) + if (EnablePrivilege(L"SeRestorePrivilege") != ERROR_SUCCESS) { fwprintf(stdout, L"INFO: The user does not have SeRestorePrivilege.\n"); } diff --git hadoop-common-project/hadoop-common/src/main/winutils/include/winutils.h hadoop-common-project/hadoop-common/src/main/winutils/include/winutils.h index 1c0007a..bae754c 100644 --- hadoop-common-project/hadoop-common/src/main/winutils/include/winutils.h +++ hadoop-common-project/hadoop-common/src/main/winutils/include/winutils.h @@ -27,6 +27,8 @@ #include #include #include +#include +#include enum EXIT_CODE { @@ -153,6 +155,26 @@ DWORD ChangeFileModeByMask(__in LPCWSTR path, INT mode); DWORD GetLocalGroupsForUser(__in LPCWSTR user, __out LPLOCALGROUP_USERS_INFO_0 *groups, __out LPDWORD entries); -BOOL EnablePrivilege(__in LPCWSTR privilegeName); - void GetLibraryName(__in LPCVOID lpAddress, __out LPWSTR *filename); + +DWORD EnablePrivilege(__in LPCWSTR privilegeName); + +void AssignLsaString(__inout LSA_STRING * target, __in const char *strBuf); + +DWORD RegisterWithLsa(__in const char *logonProcessName, __out HANDLE * lsaHandle); + +void UnregisterWithLsa(__in HANDLE lsaHandle); + +DWORD LookupKerberosAuthenticationPackageId(__in HANDLE lsaHandle, __out ULONG * packageId); + +DWORD CreateLogonForUser(__in HANDLE lsaHandle, + __in const char * tokenSourceName, + __in const char * tokenOriginName, + __in ULONG authnPkgId, + __in const wchar_t* principalName, + __out HANDLE *tokenHandle); + +DWORD LoadUserProfileForLogon(__in HANDLE logonHandle, __out PROFILEINFO * pi); + +DWORD UnloadProfileForLogon(__in HANDLE logonHandle, __in PROFILEINFO * pi); + diff --git hadoop-common-project/hadoop-common/src/main/winutils/libwinutils.c hadoop-common-project/hadoop-common/src/main/winutils/libwinutils.c index 391247f..3de458c 100644 --- hadoop-common-project/hadoop-common/src/main/winutils/libwinutils.c +++ hadoop-common-project/hadoop-common/src/main/winutils/libwinutils.c @@ -17,6 +17,8 @@ #pragma comment(lib, "authz.lib") #pragma comment(lib, "netapi32.lib") +#pragma comment(lib, "Secur32.lib") +#pragma comment(lib, "Userenv.lib") #include "winutils.h" #include #include @@ -235,10 +237,10 @@ ConvertToLongPathExit: // Function: IsDirFileInfo // // Description: -// Test if the given file information is a directory +// Test if the given file information is a directory // // Returns: -// TRUE if it is a directory +// TRUE if it is a directory // FALSE otherwise // // Notes: @@ -255,10 +257,10 @@ BOOL IsDirFileInfo(const BY_HANDLE_FILE_INFORMATION *fileInformation) // Function: CheckFileAttributes // // Description: -// Check if the given file has all the given attribute(s) +// Check if the given file has all the given attribute(s) // // Returns: -// ERROR_SUCCESS on success +// ERROR_SUCCESS on success // error code otherwise // // Notes: @@ -279,10 +281,10 @@ static DWORD FileAttributesCheck( // Function: IsDirectory // // Description: -// Check if the given file is a directory +// Check if the given file is a directory // // Returns: -// ERROR_SUCCESS on success +// ERROR_SUCCESS on success // error code otherwise // // Notes: @@ -296,10 +298,10 @@ DWORD DirectoryCheck(__in LPCWSTR pathName, __out PBOOL res) // Function: IsReparsePoint // // Description: -// Check if the given file is a reparse point +// Check if the given file is a reparse point // // Returns: -// ERROR_SUCCESS on success +// ERROR_SUCCESS on success // error code otherwise // // Notes: @@ -313,10 +315,10 @@ static DWORD ReparsePointCheck(__in LPCWSTR pathName, __out PBOOL res) // Function: CheckReparseTag // // Description: -// Check if the given file is a reparse point of the given tag. +// Check if the given file is a reparse point of the given tag. // // Returns: -// ERROR_SUCCESS on success +// ERROR_SUCCESS on success // error code otherwise // // Notes: @@ -354,10 +356,10 @@ static DWORD ReparseTagCheck(__in LPCWSTR path, __in DWORD tag, __out PBOOL res) // Function: IsSymbolicLink // // Description: -// Check if the given file is a symbolic link. +// Check if the given file is a symbolic link. // // Returns: -// ERROR_SUCCESS on success +// ERROR_SUCCESS on success // error code otherwise // // Notes: @@ -371,10 +373,10 @@ DWORD SymbolicLinkCheck(__in LPCWSTR pathName, __out PBOOL res) // Function: IsJunctionPoint // // Description: -// Check if the given file is a junction point. +// Check if the given file is a junction point. // // Returns: -// ERROR_SUCCESS on success +// ERROR_SUCCESS on success // error code otherwise // // Notes: @@ -388,14 +390,14 @@ DWORD JunctionPointCheck(__in LPCWSTR pathName, __out PBOOL res) // Function: GetSidFromAcctNameW // // Description: -// To retrieve the SID for a user account +// To retrieve the SID for a user account // // Returns: -// ERROR_SUCCESS: on success +// ERROR_SUCCESS: on success // Other error code: otherwise // // Notes: -// Caller needs to destroy the memory of Sid by calling LocalFree() +// Caller needs to destroy the memory of Sid by calling LocalFree() // DWORD GetSidFromAcctNameW(__in PCWSTR acctName, __out PSID *ppSid) { @@ -477,10 +479,10 @@ DWORD GetSidFromAcctNameW(__in PCWSTR acctName, __out PSID *ppSid) // Function: GetUnixAccessMask // // Description: -// Compute the 3 bit Unix mask for the owner, group, or, others +// Compute the 3 bit Unix mask for the owner, group, or, others // // Returns: -// The 3 bit Unix mask in INT +// The 3 bit Unix mask in INT // // Notes: // @@ -504,10 +506,10 @@ static INT GetUnixAccessMask(ACCESS_MASK Mask) // Function: GetAccess // // Description: -// Get Windows acces mask by AuthZ methods +// Get Windows acces mask by AuthZ methods // // Returns: -// ERROR_SUCCESS: on success +// ERROR_SUCCESS: on success // // Notes: // @@ -552,10 +554,10 @@ static DWORD GetAccess(AUTHZ_CLIENT_CONTEXT_HANDLE hAuthzClient, // Function: GetEffectiveRightsForSid // // Description: -// Get Windows acces mask by AuthZ methods +// Get Windows acces mask by AuthZ methods // // Returns: -// ERROR_SUCCESS: on success +// ERROR_SUCCESS: on success // // Notes: // We run into problems for local user accounts when using the method @@ -712,11 +714,11 @@ CheckAccessEnd: // Function: FindFileOwnerAndPermissionByHandle // // Description: -// Find the owner, primary group and permissions of a file object given the +// Find the owner, primary group and permissions of a file object given the // the file object handle. The function will always follow symbolic links. // // Returns: -// ERROR_SUCCESS: on success +// ERROR_SUCCESS: on success // Error code otherwise // // Notes: @@ -776,10 +778,10 @@ FindFileOwnerAndPermissionByHandleEnd: // Function: FindFileOwnerAndPermission // // Description: -// Find the owner, primary group and permissions of a file object +// Find the owner, primary group and permissions of a file object // // Returns: -// ERROR_SUCCESS: on success +// ERROR_SUCCESS: on success // Error code otherwise // // Notes: @@ -796,8 +798,7 @@ DWORD FindFileOwnerAndPermission( __out_opt LPWSTR *pGroupName, __out_opt PINT pMask) { - DWORD dwRtnCode = 0; - + DWORD dwRtnCode = 0; PSECURITY_DESCRIPTOR pSd = NULL; PSID psidOwner = NULL; @@ -1439,14 +1440,14 @@ ChangeFileModeByMaskEnd: // Function: GetAccntNameFromSid // // Description: -// To retrieve an account name given the SID +// To retrieve an account name given the SID // // Returns: -// ERROR_SUCCESS: on success +// ERROR_SUCCESS: on success // Other error code: otherwise // // Notes: -// Caller needs to destroy the memory of account name by calling LocalFree() +// Caller needs to destroy the memory of account name by calling LocalFree() // DWORD GetAccntNameFromSid(__in PSID pSid, __out PWSTR *ppAcctName) { @@ -1535,10 +1536,10 @@ GetAccntNameFromSidEnd: // Function: GetLocalGroupsForUser // // Description: -// Get an array of groups for the given user. +// Get an array of groups for the given user. // // Returns: -// ERROR_SUCCESS on success +// ERROR_SUCCESS on success // Other error code on failure // // Notes: @@ -1634,15 +1635,16 @@ GetLocalGroupsForUserEnd: // Function: EnablePrivilege // // Description: -// Check if the process has the given privilege. If yes, enable the privilege +// Check if the process has the given privilege. If yes, enable the privilege // to the process's access token. // // Returns: -// TRUE: on success +// ERROR_SUCCESS on success +// GetLastError() on error // // Notes: // -BOOL EnablePrivilege(__in LPCWSTR privilegeName) +DWORD EnablePrivilege(__in LPCWSTR privilegeName) { HANDLE hToken = INVALID_HANDLE_VALUE; TOKEN_PRIVILEGES tp = { 0 }; @@ -1651,28 +1653,31 @@ BOOL EnablePrivilege(__in LPCWSTR privilegeName) if (!OpenProcessToken(GetCurrentProcess(), TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY, &hToken)) { - ReportErrorCode(L"OpenProcessToken", GetLastError()); - return FALSE; + dwErrCode = GetLastError(); + ReportErrorCode(L"OpenProcessToken", dwErrCode); + return dwErrCode; } tp.PrivilegeCount = 1; if (!LookupPrivilegeValueW(NULL, privilegeName, &(tp.Privileges[0].Luid))) { - ReportErrorCode(L"LookupPrivilegeValue", GetLastError()); + dwErrCode = GetLastError(); + ReportErrorCode(L"LookupPrivilegeValue", dwErrCode); CloseHandle(hToken); - return FALSE; + return dwErrCode; } tp.Privileges[0].Attributes = SE_PRIVILEGE_ENABLED; // As stated on MSDN, we need to use GetLastError() to check if // AdjustTokenPrivileges() adjusted all of the specified privileges. // - AdjustTokenPrivileges(hToken, FALSE, &tp, 0, NULL, NULL); - dwErrCode = GetLastError(); + if( !AdjustTokenPrivileges(hToken, FALSE, &tp, 0, NULL, NULL) ) { + dwErrCode = GetLastError(); + } CloseHandle(hToken); - return dwErrCode == ERROR_SUCCESS; + return dwErrCode; } //---------------------------------------------------------------------------- @@ -1716,9 +1721,6 @@ void ReportErrorCode(LPCWSTR func, DWORD err) // Description: // Given an address, get the file name of the library from which it was loaded. // -// Returns: -// None -// // Notes: // - The function allocates heap memory and points the filename out parameter to // the newly allocated memory, which will contain the name of the file. @@ -1757,3 +1759,290 @@ cleanup: *filename = NULL; } } + +// Function: AssignLsaString +// +// Description: +// fills in values of LSA_STRING struct to point to a string buffer +// +// Returns: +// None +// +// IMPORTANT*** strBuf is not copied. It must be globally immutable +// +void AssignLsaString(__inout LSA_STRING * target, __in const char *strBuf) +{ + target->Length = (USHORT)(sizeof(char)*strlen(strBuf)); + target->MaximumLength = target->Length; + target->Buffer = (char *)(strBuf); +} + +//---------------------------------------------------------------------------- +// Function: RegisterWithLsa +// +// Description: +// Registers with local security authority and sets handle for use in later LSA +// operations +// +// Returns: +// ERROR_SUCCESS on success +// Other error code on failure +// +// Notes: +// +DWORD RegisterWithLsa(__in const char *logonProcessName, __out HANDLE * lsaHandle) +{ + LSA_STRING processName; + LSA_OPERATIONAL_MODE o_mode; // never useful as per msdn docs + NTSTATUS registerStatus; + *lsaHandle = 0; + + AssignLsaString(&processName, logonProcessName); + registerStatus = LsaRegisterLogonProcess(&processName, lsaHandle, &o_mode); + + return LsaNtStatusToWinError( registerStatus ); +} + +//---------------------------------------------------------------------------- +// Function: UnregisterWithLsa +// +// Description: +// Closes LSA handle allocated by RegisterWithLsa() +// +// Returns: +// None +// +// Notes: +// +void UnregisterWithLsa(__in HANDLE lsaHandle) +{ + LsaClose(lsaHandle); +} + +//---------------------------------------------------------------------------- +// Function: LookupKerberosAuthenticationPackageId +// +// Description: +// Looks of the current id (integer index) of the Kerberos authentication package on the local +// machine. +// +// Returns: +// ERROR_SUCCESS on success +// Other error code on failure +// +// Notes: +// +DWORD LookupKerberosAuthenticationPackageId(__in HANDLE lsaHandle, __out ULONG * packageId) +{ + NTSTATUS lookupStatus; + LSA_STRING pkgName; + + AssignLsaString(&pkgName, MICROSOFT_KERBEROS_NAME_A); + lookupStatus = LsaLookupAuthenticationPackage(lsaHandle, &pkgName, packageId); + return LsaNtStatusToWinError( lookupStatus ); +} + +//---------------------------------------------------------------------------- +// Function: CreateLogonForUser +// +// Description: +// Contacts the local LSA and performs a logon without credential for the +// given principal. This logon token will be local machine only and have no +// network credentials attached. +// +// Returns: +// ERROR_SUCCESS on success +// Other error code on failure +// +// Notes: +// This call assumes that all required privileges have already been enabled (TCB etc). +// IMPORTANT **** tokenOriginName must be immutable! +// +DWORD CreateLogonForUser(__in HANDLE lsaHandle, + __in const char * tokenSourceName, + __in const char * tokenOriginName, // must be immutable, will not be copied! + __in ULONG authnPkgId, + __in const wchar_t* principalName, + __out HANDLE *tokenHandle) +{ + DWORD logonStatus = ERROR_ASSERTION_FAILURE; // Failure to set status should trigger error + TOKEN_SOURCE tokenSource; + LSA_STRING originName; + void * profile = NULL; + + // from MSDN: + // The ClientUpn and ClientRealm members of the KERB_S4U_LOGON + // structure must point to buffers in memory that are contiguous + // to the structure itself. The value of the + // AuthenticationInformationLength parameter must take into + // account the length of these buffers. + const int principalNameBufLen = lstrlen(principalName)*sizeof(*principalName); + const int totalAuthInfoLen = sizeof(KERB_S4U_LOGON) + principalNameBufLen; + KERB_S4U_LOGON* s4uLogonAuthInfo = (KERB_S4U_LOGON*)calloc(totalAuthInfoLen, 1); + if (s4uLogonAuthInfo == NULL ) { + logonStatus = ERROR_NOT_ENOUGH_MEMORY; + goto done; + } + s4uLogonAuthInfo->MessageType = KerbS4ULogon; + s4uLogonAuthInfo->ClientUpn.Buffer = (wchar_t*)((char*)s4uLogonAuthInfo + sizeof *s4uLogonAuthInfo); + CopyMemory(s4uLogonAuthInfo->ClientUpn.Buffer, principalName, principalNameBufLen); + s4uLogonAuthInfo->ClientUpn.Length = (USHORT)principalNameBufLen; + s4uLogonAuthInfo->ClientUpn.MaximumLength = (USHORT)principalNameBufLen; + + AllocateLocallyUniqueId(&tokenSource.SourceIdentifier); + StringCchCopyA(tokenSource.SourceName, TOKEN_SOURCE_LENGTH, tokenSourceName ); + AssignLsaString(&originName, tokenOriginName); + + { + DWORD cbProfile = 0; + LUID logonId; + QUOTA_LIMITS quotaLimits; + NTSTATUS subStatus; + + NTSTATUS logonNtStatus = LsaLogonUser(lsaHandle, + &originName, + Batch, // SECURITY_LOGON_TYPE + authnPkgId, + s4uLogonAuthInfo, + totalAuthInfoLen, + 0, + &tokenSource, + &profile, + &cbProfile, + &logonId, + tokenHandle, + "aLimits, + &subStatus); + logonStatus = LsaNtStatusToWinError( logonNtStatus ); + } +done: + // clean up + if (s4uLogonAuthInfo != NULL) { + free(s4uLogonAuthInfo); + } + if (profile != NULL) { + LsaFreeReturnBuffer(profile); + } + return logonStatus; +} + +// NOTE: must free allocatedName +DWORD GetNameFromLogonToken(__in HANDLE logonToken, __out wchar_t **allocatedName) +{ + DWORD userInfoSize = 0; + PTOKEN_USER user = NULL; + DWORD userNameSize = 0; + wchar_t * userName = NULL; + DWORD domainNameSize = 0; + wchar_t * domainName = NULL; + SID_NAME_USE sidUse = SidTypeUnknown; + DWORD getNameStatus = ERROR_ASSERTION_FAILURE; // Failure to set status should trigger error + BOOL tokenInformation = FALSE; + + // call for sid size then alloc and call for sid + tokenInformation = GetTokenInformation(logonToken, TokenUser, NULL, 0, &userInfoSize); + assert (FALSE == tokenInformation); + + // last call should have failed and filled in allocation size + if ((getNameStatus = GetLastError()) != ERROR_INSUFFICIENT_BUFFER) + { + goto done; + } + user = (PTOKEN_USER)calloc(userInfoSize,1); + if (user == NULL) + { + getNameStatus = ERROR_NOT_ENOUGH_MEMORY; + goto done; + } + if (!GetTokenInformation(logonToken, TokenUser, user, userInfoSize, &userInfoSize)) { + getNameStatus = GetLastError(); + goto done; + } + LookupAccountSid( NULL, user->User.Sid, NULL, &userNameSize, NULL, &domainNameSize, &sidUse ); + // last call should have failed and filled in allocation size + if ((getNameStatus = GetLastError()) != ERROR_INSUFFICIENT_BUFFER) + { + goto done; + } + userName = (wchar_t *)calloc(userNameSize, sizeof(wchar_t)); + if (userName == NULL) { + getNameStatus = ERROR_NOT_ENOUGH_MEMORY; + goto done; + } + domainName = (wchar_t *)calloc(domainNameSize, sizeof(wchar_t)); + if (domainName == NULL) { + getNameStatus = ERROR_NOT_ENOUGH_MEMORY; + goto done; + } + if (!LookupAccountSid( NULL, user->User.Sid, userName, &userNameSize, domainName, &domainNameSize, &sidUse )) { + getNameStatus = GetLastError(); + goto done; + } + + getNameStatus = ERROR_SUCCESS; + *allocatedName = userName; + userName = NULL; +done: + if (user != NULL) { + free( user ); + user = NULL; + } + if (userName != NULL) { + free( userName ); + userName = NULL; + } + if (domainName != NULL) { + free( domainName ); + domainName = NULL; + } + return getNameStatus; +} + +DWORD LoadUserProfileForLogon(__in HANDLE logonHandle, __out PROFILEINFO * pi) +{ + wchar_t *userName = NULL; + DWORD loadProfileStatus = ERROR_ASSERTION_FAILURE; // Failure to set status should trigger error + + loadProfileStatus = GetNameFromLogonToken( logonHandle, &userName ); + if (loadProfileStatus != ERROR_SUCCESS) { + goto done; + } + + assert(pi); + + ZeroMemory( pi, sizeof(*pi) ); + pi->dwSize = sizeof(*pi); + pi->lpUserName = userName; + pi->dwFlags = PI_NOUI; + + // if the profile does not exist it will be created + if ( !LoadUserProfile( logonHandle, pi ) ) { + loadProfileStatus = GetLastError(); + goto done; + } + + loadProfileStatus = ERROR_SUCCESS; +done: + return loadProfileStatus; +} + +DWORD UnloadProfileForLogon(__in HANDLE logonHandle, __in PROFILEINFO * pi) +{ + DWORD touchProfileStatus = ERROR_ASSERTION_FAILURE; // Failure to set status should trigger error + + assert(pi); + + if ( !UnloadUserProfile(logonHandle, pi->hProfile ) ) { + touchProfileStatus = GetLastError(); + goto done; + } + if (pi->lpUserName != NULL) { + free(pi->lpUserName); + pi->lpUserName = NULL; + } + ZeroMemory( pi, sizeof(*pi) ); + + touchProfileStatus = ERROR_SUCCESS; +done: + return touchProfileStatus; +} diff --git hadoop-common-project/hadoop-common/src/main/winutils/symlink.c hadoop-common-project/hadoop-common/src/main/winutils/symlink.c index ea372cc..02acd4d 100644 --- hadoop-common-project/hadoop-common/src/main/winutils/symlink.c +++ hadoop-common-project/hadoop-common/src/main/winutils/symlink.c @@ -77,7 +77,7 @@ int Symlink(__in int argc, __in_ecount(argc) wchar_t *argv[]) // This is just an additional step to do the privilege check by not using // error code from CreateSymbolicLink() method. // - if (!EnablePrivilege(L"SeCreateSymbolicLinkPrivilege")) + if (EnablePrivilege(L"SeCreateSymbolicLinkPrivilege") != ERROR_SUCCESS) { fwprintf(stderr, L"No privilege to create symbolic links.\n"); diff --git hadoop-common-project/hadoop-common/src/main/winutils/task.c hadoop-common-project/hadoop-common/src/main/winutils/task.c index 19bda96..81f2e4f 100644 --- hadoop-common-project/hadoop-common/src/main/winutils/task.c +++ hadoop-common-project/hadoop-common/src/main/winutils/task.c @@ -1,467 +1,738 @@ -/** -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with this -* work for additional information regarding copyright ownership. The ASF -* licenses this file to you under the Apache License, Version 2.0 (the -* "License"); you may not use this file except in compliance with the License. -* You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -* License for the specific language governing permissions and limitations under -* the License. -*/ - -#include "winutils.h" -#include -#include - -#define PSAPI_VERSION 1 -#pragma comment(lib, "psapi.lib") - -#define ERROR_TASK_NOT_ALIVE 1 - -// This exit code for killed processes is compatible with Unix, where a killed -// process exits with 128 + signal. For SIGKILL, this would be 128 + 9 = 137. -#define KILLED_PROCESS_EXIT_CODE 137 - -// List of different task related command line options supported by -// winutils. -typedef enum TaskCommandOptionType -{ - TaskInvalid, - TaskCreate, - TaskIsAlive, - TaskKill, - TaskProcessList -} TaskCommandOption; - -//---------------------------------------------------------------------------- -// Function: ParseCommandLine -// -// Description: -// Parses the given command line. On success, out param 'command' contains -// the user specified command. -// -// Returns: -// TRUE: If the command line is valid -// FALSE: otherwise -static BOOL ParseCommandLine(__in int argc, - __in_ecount(argc) wchar_t *argv[], - __out TaskCommandOption *command) -{ - *command = TaskInvalid; - - if (wcscmp(argv[0], L"task") != 0 ) - { - return FALSE; - } - - if (argc == 3) { - if (wcscmp(argv[1], L"isAlive") == 0) - { - *command = TaskIsAlive; - return TRUE; - } - if (wcscmp(argv[1], L"kill") == 0) - { - *command = TaskKill; - return TRUE; - } - if (wcscmp(argv[1], L"processList") == 0) - { - *command = TaskProcessList; - return TRUE; - } - } - - if (argc == 4) { - if (wcscmp(argv[1], L"create") == 0) - { - *command = TaskCreate; - return TRUE; - } - } - - return FALSE; -} - -//---------------------------------------------------------------------------- -// Function: createTask -// -// Description: -// Creates a task via a jobobject. Outputs the -// appropriate information to stdout on success, or stderr on failure. -// -// Returns: -// ERROR_SUCCESS: On success -// GetLastError: otherwise -DWORD createTask(__in PCWSTR jobObjName,__in PWSTR cmdLine) -{ - DWORD err = ERROR_SUCCESS; - DWORD exitCode = EXIT_FAILURE; - STARTUPINFO si; - PROCESS_INFORMATION pi; - HANDLE jobObject = NULL; - JOBOBJECT_EXTENDED_LIMIT_INFORMATION jeli = { 0 }; - - // Create un-inheritable job object handle and set job object to terminate - // when last handle is closed. So winutils.exe invocation has the only open - // job object handle. Exit of winutils.exe ensures termination of job object. - // Either a clean exit of winutils or crash or external termination. - jobObject = CreateJobObject(NULL, jobObjName); - err = GetLastError(); - if(jobObject == NULL || err == ERROR_ALREADY_EXISTS) - { - return err; - } - jeli.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; - if(SetInformationJobObject(jobObject, - JobObjectExtendedLimitInformation, - &jeli, - sizeof(jeli)) == 0) - { - err = GetLastError(); - CloseHandle(jobObject); - return err; - } - - if(AssignProcessToJobObject(jobObject, GetCurrentProcess()) == 0) - { - err = GetLastError(); - CloseHandle(jobObject); - return err; - } - - // the child JVM uses this env var to send the task OS process identifier - // to the TaskTracker. We pass the job object name. - if(SetEnvironmentVariable(L"JVM_PID", jobObjName) == 0) - { - err = GetLastError(); - CloseHandle(jobObject); - return err; - } - - ZeroMemory( &si, sizeof(si) ); - si.cb = sizeof(si); - ZeroMemory( &pi, sizeof(pi) ); - - if (CreateProcess(NULL, cmdLine, NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi) == 0) - { - err = GetLastError(); - CloseHandle(jobObject); - return err; - } - - CloseHandle(pi.hThread); - - // Wait until child process exits. - WaitForSingleObject( pi.hProcess, INFINITE ); - if(GetExitCodeProcess(pi.hProcess, &exitCode) == 0) - { - err = GetLastError(); - } - CloseHandle( pi.hProcess ); - - // Terminate job object so that all spawned processes are also killed. - // This is needed because once this process closes the handle to the job - // object and none of the spawned objects have the handle open (via - // inheritance on creation) then it will not be possible for any other external - // program (say winutils task kill) to terminate this job object via its name. - if(TerminateJobObject(jobObject, exitCode) == 0) - { - err = GetLastError(); - } - - // comes here only on failure or TerminateJobObject - CloseHandle(jobObject); - - if(err != ERROR_SUCCESS) - { - return err; - } - return exitCode; -} - -//---------------------------------------------------------------------------- -// Function: isTaskAlive -// -// Description: -// Checks if a task is alive via a jobobject. Outputs the -// appropriate information to stdout on success, or stderr on failure. -// -// Returns: -// ERROR_SUCCESS: On success -// GetLastError: otherwise -DWORD isTaskAlive(const WCHAR* jobObjName, int* isAlive, int* procsInJob) -{ - PJOBOBJECT_BASIC_PROCESS_ID_LIST procList; - HANDLE jobObject = NULL; - int numProcs = 100; - - *isAlive = FALSE; - - jobObject = OpenJobObject(JOB_OBJECT_QUERY, FALSE, jobObjName); - - if(jobObject == NULL) - { - DWORD err = GetLastError(); - if(err == ERROR_FILE_NOT_FOUND) - { - // job object does not exist. assume its not alive - return ERROR_SUCCESS; - } - return err; - } - - procList = (PJOBOBJECT_BASIC_PROCESS_ID_LIST) LocalAlloc(LPTR, sizeof (JOBOBJECT_BASIC_PROCESS_ID_LIST) + numProcs*32); - if (!procList) - { - DWORD err = GetLastError(); - CloseHandle(jobObject); - return err; - } - if(QueryInformationJobObject(jobObject, JobObjectBasicProcessIdList, procList, sizeof(JOBOBJECT_BASIC_PROCESS_ID_LIST)+numProcs*32, NULL) == 0) - { - DWORD err = GetLastError(); - if(err != ERROR_MORE_DATA) - { - CloseHandle(jobObject); - LocalFree(procList); - return err; - } - } - - if(procList->NumberOfAssignedProcesses > 0) - { - *isAlive = TRUE; - *procsInJob = procList->NumberOfAssignedProcesses; - } - - LocalFree(procList); - - return ERROR_SUCCESS; -} - -//---------------------------------------------------------------------------- -// Function: killTask -// -// Description: -// Kills a task via a jobobject. Outputs the -// appropriate information to stdout on success, or stderr on failure. -// -// Returns: -// ERROR_SUCCESS: On success -// GetLastError: otherwise -DWORD killTask(PCWSTR jobObjName) -{ - HANDLE jobObject = OpenJobObject(JOB_OBJECT_TERMINATE, FALSE, jobObjName); - if(jobObject == NULL) - { - DWORD err = GetLastError(); - if(err == ERROR_FILE_NOT_FOUND) - { - // job object does not exist. assume its not alive - return ERROR_SUCCESS; - } - return err; - } - - if(TerminateJobObject(jobObject, KILLED_PROCESS_EXIT_CODE) == 0) - { - return GetLastError(); - } - CloseHandle(jobObject); - - return ERROR_SUCCESS; -} - -//---------------------------------------------------------------------------- -// Function: printTaskProcessList -// -// Description: -// Prints resource usage of all processes in the task jobobject -// -// Returns: -// ERROR_SUCCESS: On success -// GetLastError: otherwise -DWORD printTaskProcessList(const WCHAR* jobObjName) -{ - DWORD i; - PJOBOBJECT_BASIC_PROCESS_ID_LIST procList; - int numProcs = 100; - HANDLE jobObject = OpenJobObject(JOB_OBJECT_QUERY, FALSE, jobObjName); - if(jobObject == NULL) - { - DWORD err = GetLastError(); - return err; - } - - procList = (PJOBOBJECT_BASIC_PROCESS_ID_LIST) LocalAlloc(LPTR, sizeof (JOBOBJECT_BASIC_PROCESS_ID_LIST) + numProcs*32); - if (!procList) - { - DWORD err = GetLastError(); - CloseHandle(jobObject); - return err; - } - while(QueryInformationJobObject(jobObject, JobObjectBasicProcessIdList, procList, sizeof(JOBOBJECT_BASIC_PROCESS_ID_LIST)+numProcs*32, NULL) == 0) - { - DWORD err = GetLastError(); - if(err != ERROR_MORE_DATA) - { - CloseHandle(jobObject); - LocalFree(procList); - return err; - } - numProcs = procList->NumberOfAssignedProcesses; - LocalFree(procList); - procList = (PJOBOBJECT_BASIC_PROCESS_ID_LIST) LocalAlloc(LPTR, sizeof (JOBOBJECT_BASIC_PROCESS_ID_LIST) + numProcs*32); - if (procList == NULL) - { - err = GetLastError(); - CloseHandle(jobObject); - return err; - } - } - - for(i=0; iNumberOfProcessIdsInList; ++i) - { - HANDLE hProcess = OpenProcess( PROCESS_QUERY_INFORMATION, FALSE, (DWORD)procList->ProcessIdList[i] ); - if( hProcess != NULL ) - { - PROCESS_MEMORY_COUNTERS_EX pmc; - if ( GetProcessMemoryInfo( hProcess, (PPROCESS_MEMORY_COUNTERS)&pmc, sizeof(pmc)) ) - { - FILETIME create, exit, kernel, user; - if( GetProcessTimes( hProcess, &create, &exit, &kernel, &user) ) - { - ULARGE_INTEGER kernelTime, userTime; - ULONGLONG cpuTimeMs; - kernelTime.HighPart = kernel.dwHighDateTime; - kernelTime.LowPart = kernel.dwLowDateTime; - userTime.HighPart = user.dwHighDateTime; - userTime.LowPart = user.dwLowDateTime; - cpuTimeMs = (kernelTime.QuadPart+userTime.QuadPart)/10000; - fwprintf_s(stdout, L"%Iu,%Iu,%Iu,%I64u\n", procList->ProcessIdList[i], pmc.PrivateUsage, pmc.WorkingSetSize, cpuTimeMs); - } - } - CloseHandle( hProcess ); - } - } - - LocalFree(procList); - CloseHandle(jobObject); - - return ERROR_SUCCESS; -} - -//---------------------------------------------------------------------------- -// Function: Task -// -// Description: -// Manages a task via a jobobject (create/isAlive/kill). Outputs the -// appropriate information to stdout on success, or stderr on failure. -// -// Returns: -// ERROR_SUCCESS: On success -// Error code otherwise: otherwise -int Task(__in int argc, __in_ecount(argc) wchar_t *argv[]) -{ - DWORD dwErrorCode = ERROR_SUCCESS; - TaskCommandOption command = TaskInvalid; - - if (!ParseCommandLine(argc, argv, &command)) { - dwErrorCode = ERROR_INVALID_COMMAND_LINE; - - fwprintf(stderr, L"Incorrect command line arguments.\n\n"); - TaskUsage(); - goto TaskExit; - } - - if (command == TaskCreate) - { - // Create the task jobobject - // - dwErrorCode = createTask(argv[2], argv[3]); - if (dwErrorCode != ERROR_SUCCESS) - { - ReportErrorCode(L"createTask", dwErrorCode); - goto TaskExit; - } - } else if (command == TaskIsAlive) - { - // Check if task jobobject - // - int isAlive; - int numProcs; - dwErrorCode = isTaskAlive(argv[2], &isAlive, &numProcs); - if (dwErrorCode != ERROR_SUCCESS) - { - ReportErrorCode(L"isTaskAlive", dwErrorCode); - goto TaskExit; - } - - // Output the result - if(isAlive == TRUE) - { - fwprintf(stdout, L"IsAlive,%d\n", numProcs); - } - else - { - dwErrorCode = ERROR_TASK_NOT_ALIVE; - ReportErrorCode(L"isTaskAlive returned false", dwErrorCode); - goto TaskExit; - } - } else if (command == TaskKill) - { - // Check if task jobobject - // - dwErrorCode = killTask(argv[2]); - if (dwErrorCode != ERROR_SUCCESS) - { - ReportErrorCode(L"killTask", dwErrorCode); - goto TaskExit; - } - } else if (command == TaskProcessList) - { - // Check if task jobobject - // - dwErrorCode = printTaskProcessList(argv[2]); - if (dwErrorCode != ERROR_SUCCESS) - { - ReportErrorCode(L"printTaskProcessList", dwErrorCode); - goto TaskExit; - } - } else - { - // Should not happen - // - assert(FALSE); - } - -TaskExit: - return dwErrorCode; -} - -void TaskUsage() -{ - // Hadoop code checks for this string to determine if - // jobobject's are being used. - // ProcessTree.isSetsidSupported() - fwprintf(stdout, L"\ - Usage: task create [TASKNAME] [COMMAND_LINE] |\n\ - task isAlive [TASKNAME] |\n\ - task kill [TASKNAME]\n\ - task processList [TASKNAME]\n\ - Creates a new task jobobject with taskname\n\ - Checks if task jobobject is alive\n\ - Kills task jobobject\n\ - Prints to stdout a list of processes in the task\n\ - along with their resource usage. One process per line\n\ - and comma separated info per process\n\ - ProcessId,VirtualMemoryCommitted(bytes),\n\ - WorkingSetSize(bytes),CpuTime(Millisec,Kernel+User)\n"); -} +/** +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with this +* work for additional information regarding copyright ownership. The ASF +* licenses this file to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations under +* the License. +*/ + +#include "winutils.h" +#include +#include +#include + +#define PSAPI_VERSION 1 +#pragma comment(lib, "psapi.lib") + +#define ERROR_TASK_NOT_ALIVE 1 + +// This exit code for killed processes is compatible with Unix, where a killed +// process exits with 128 + signal. For SIGKILL, this would be 128 + 9 = 137. +#define KILLED_PROCESS_EXIT_CODE 137 + +// Name for tracking this logon process when registering with LSA +static const char *LOGON_PROCESS_NAME="Hadoop Container Executor"; +// Name for token source, must be less or eq to TOKEN_SOURCE_LENGTH (currently 8) chars +static const char *TOKEN_SOURCE_NAME = "HadoopEx"; + +// List of different task related command line options supported by +// winutils. +typedef enum TaskCommandOptionType +{ + TaskInvalid, + TaskCreate, + TaskCreateAsUser, + TaskIsAlive, + TaskKill, + TaskProcessList +} TaskCommandOption; + +//---------------------------------------------------------------------------- +// Function: ParseCommandLine +// +// Description: +// Parses the given command line. On success, out param 'command' contains +// the user specified command. +// +// Returns: +// TRUE: If the command line is valid +// FALSE: otherwise +static BOOL ParseCommandLine(__in int argc, + __in_ecount(argc) wchar_t *argv[], + __out TaskCommandOption *command) +{ + *command = TaskInvalid; + + if (wcscmp(argv[0], L"task") != 0 ) + { + return FALSE; + } + + if (argc == 3) { + if (wcscmp(argv[1], L"isAlive") == 0) + { + *command = TaskIsAlive; + return TRUE; + } + if (wcscmp(argv[1], L"kill") == 0) + { + *command = TaskKill; + return TRUE; + } + if (wcscmp(argv[1], L"processList") == 0) + { + *command = TaskProcessList; + return TRUE; + } + } + + if (argc == 4) { + if (wcscmp(argv[1], L"create") == 0) + { + *command = TaskCreate; + return TRUE; + } + } + + if (argc >= 6) { + if (wcscmp(argv[1], L"createAsUser") == 0) + { + *command = TaskCreateAsUser; + return TRUE; + } + } + + return FALSE; +} + +//---------------------------------------------------------------------------- +// Function: CreateTaskImpl +// +// Description: +// Creates a task via a jobobject. Outputs the +// appropriate information to stdout on success, or stderr on failure. +// logonHandle may be NULL, in this case the current logon will be utilized for the +// created process +// +// Returns: +// ERROR_SUCCESS: On success +// GetLastError: otherwise +DWORD CreateTaskImpl(__in_opt HANDLE logonHandle, __in PCWSTR jobObjName,__in PWSTR cmdLine) +{ + DWORD dwErrorCode = ERROR_SUCCESS; + DWORD exitCode = EXIT_FAILURE; + DWORD currDirCnt = 0; + STARTUPINFO si; + PROCESS_INFORMATION pi; + HANDLE jobObject = NULL; + JOBOBJECT_EXTENDED_LIMIT_INFORMATION jeli = { 0 }; + void * envBlock = NULL; + BOOL createProcessResult = FALSE; + + wchar_t* curr_dir = NULL; + FILE *stream = NULL; + + // Create un-inheritable job object handle and set job object to terminate + // when last handle is closed. So winutils.exe invocation has the only open + // job object handle. Exit of winutils.exe ensures termination of job object. + // Either a clean exit of winutils or crash or external termination. + jobObject = CreateJobObject(NULL, jobObjName); + dwErrorCode = GetLastError(); + if(jobObject == NULL || dwErrorCode == ERROR_ALREADY_EXISTS) + { + return dwErrorCode; + } + jeli.BasicLimitInformation.LimitFlags = JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE; + if(SetInformationJobObject(jobObject, + JobObjectExtendedLimitInformation, + &jeli, + sizeof(jeli)) == 0) + { + dwErrorCode = GetLastError(); + CloseHandle(jobObject); + return dwErrorCode; + } + + if(AssignProcessToJobObject(jobObject, GetCurrentProcess()) == 0) + { + dwErrorCode = GetLastError(); + CloseHandle(jobObject); + return dwErrorCode; + } + + // the child JVM uses this env var to send the task OS process identifier + // to the TaskTracker. We pass the job object name. + if(SetEnvironmentVariable(L"JVM_PID", jobObjName) == 0) + { + dwErrorCode = GetLastError(); + // We have to explictly Terminate, passing in the error code + // simply closing the job would kill our own process with success exit status + TerminateJobObject(jobObject, dwErrorCode); + return dwErrorCode; + } + + ZeroMemory( &si, sizeof(si) ); + si.cb = sizeof(si); + ZeroMemory( &pi, sizeof(pi) ); + + if( logonHandle != NULL ) { + // create user environment for this logon + if(!CreateEnvironmentBlock(&envBlock, + logonHandle, + FALSE )) { + dwErrorCode = GetLastError(); + // We have to explictly Terminate, passing in the error code + // simply closing the job would kill our own process with success exit status + TerminateJobObject(jobObject, dwErrorCode); + return dwErrorCode; + } + } + + // Get the required buffer size first + currDirCnt = GetCurrentDirectory(0, NULL); + if (0 < currDirCnt) { + curr_dir = (wchar_t*) alloca(currDirCnt * sizeof(wchar_t)); + assert(curr_dir); + currDirCnt = GetCurrentDirectory(currDirCnt, curr_dir); + } + + if (0 == currDirCnt) { + dwErrorCode = GetLastError(); + // We have to explictly Terminate, passing in the error code + // simply closing the job would kill our own process with success exit status + TerminateJobObject(jobObject, dwErrorCode); + return dwErrorCode; + } + + if (logonHandle == NULL) { + createProcessResult = CreateProcess( + NULL, // ApplicationName + cmdLine, // command line + NULL, // process security attributes + NULL, // thread security attributes + TRUE, // inherit handles + 0, // creation flags + NULL, // environment + curr_dir, // current directory + &si, // startup info + &pi); // process info + } + else { + createProcessResult = CreateProcessAsUser( + logonHandle, // logon token handle + NULL, // Application handle + cmdLine, // command line + NULL, // process security attributes + NULL, // thread security attributes + FALSE, // inherit handles + CREATE_UNICODE_ENVIRONMENT, // creation flags + envBlock, // environment + curr_dir, // current directory + &si, // startup info + &pi); // process info + } + + if (FALSE == createProcessResult) { + dwErrorCode = GetLastError(); + if( envBlock != NULL ) { + DestroyEnvironmentBlock( envBlock ); + envBlock = NULL; + } + // We have to explictly Terminate, passing in the error code + // simply closing the job would kill our own process with success exit status + TerminateJobObject(jobObject, dwErrorCode); + + // This is tehnically dead code, we cannot reach this condition + return dwErrorCode; + } + + CloseHandle(pi.hThread); + + // Wait until child process exits. + WaitForSingleObject( pi.hProcess, INFINITE ); + if(GetExitCodeProcess(pi.hProcess, &exitCode) == 0) + { + dwErrorCode = GetLastError(); + } + CloseHandle( pi.hProcess ); + + if( envBlock != NULL ) { + DestroyEnvironmentBlock( envBlock ); + envBlock = NULL; + } + + // Terminate job object so that all spawned processes are also killed. + // This is needed because once this process closes the handle to the job + // object and none of the spawned objects have the handle open (via + // inheritance on creation) then it will not be possible for any other external + // program (say winutils task kill) to terminate this job object via its name. + if(TerminateJobObject(jobObject, exitCode) == 0) + { + dwErrorCode = GetLastError(); + } + + // comes here only on failure of TerminateJobObject + CloseHandle(jobObject); + + if(dwErrorCode != ERROR_SUCCESS) + { + return dwErrorCode; + } + return exitCode; +} + +//---------------------------------------------------------------------------- +// Function: CreateTask +// +// Description: +// Creates a task via a jobobject. Outputs the +// appropriate information to stdout on success, or stderr on failure. +// +// Returns: +// ERROR_SUCCESS: On success +// GetLastError: otherwise +DWORD CreateTask(__in PCWSTR jobObjName,__in PWSTR cmdLine) +{ + // call with null logon in order to create tasks utilizing the current logon + return CreateTaskImpl( NULL, jobObjName, cmdLine ); +} +//---------------------------------------------------------------------------- +// Function: CreateTask +// +// Description: +// Creates a task via a jobobject. Outputs the +// appropriate information to stdout on success, or stderr on failure. +// +// Returns: +// ERROR_SUCCESS: On success +// GetLastError: otherwise +DWORD CreateTaskAsUser(__in PCWSTR jobObjName,__in PWSTR user, __in PWSTR pidFilePath, __in PWSTR cmdLine) +{ + DWORD err = ERROR_SUCCESS; + DWORD exitCode = EXIT_FAILURE; + ULONG authnPkgId; + HANDLE lsaHandle = INVALID_HANDLE_VALUE; + PROFILEINFO pi; + BOOL profileIsLoaded = FALSE; + FILE* pidFile = NULL; + + DWORD retLen = 0; + HANDLE logonHandle = NULL; + + err = EnablePrivilege(SE_TCB_NAME); + if( err != ERROR_SUCCESS ) { + fwprintf(stdout, L"INFO: The user does not have SE_TCB_NAME.\n"); + goto done; + } + err = EnablePrivilege(SE_ASSIGNPRIMARYTOKEN_NAME); + if( err != ERROR_SUCCESS ) { + fwprintf(stdout, L"INFO: The user does not have SE_ASSIGNPRIMARYTOKEN_NAME.\n"); + goto done; + } + err = EnablePrivilege(SE_INCREASE_QUOTA_NAME); + if( err != ERROR_SUCCESS ) { + fwprintf(stdout, L"INFO: The user does not have SE_INCREASE_QUOTA_NAME.\n"); + goto done; + } + err = EnablePrivilege(SE_RESTORE_NAME); + if( err != ERROR_SUCCESS ) { + fwprintf(stdout, L"INFO: The user does not have SE_RESTORE_NAME.\n"); + goto done; + } + + err = RegisterWithLsa(LOGON_PROCESS_NAME ,&lsaHandle); + if( err != ERROR_SUCCESS ) goto done; + + err = LookupKerberosAuthenticationPackageId( lsaHandle, &authnPkgId ); + if( err != ERROR_SUCCESS ) goto done; + + err = CreateLogonForUser(lsaHandle, + LOGON_PROCESS_NAME, + TOKEN_SOURCE_NAME, + authnPkgId, + user, + &logonHandle); + if( err != ERROR_SUCCESS ) goto done; + + err = LoadUserProfileForLogon(logonHandle, &pi); + if( err != ERROR_SUCCESS ) goto done; + profileIsLoaded = TRUE; + + // Create the PID file + + if (!(pidFile = _wfopen(pidFilePath, "w"))) { + err = GetLastError(); + goto done; + } + + if (0 > fprintf_s(pidFile, "%ls", jobObjName)) { + err = GetLastError(); + } + + fclose(pidFile); + + if (err != ERROR_SUCCESS) { + goto done; + } + + err = CreateTaskImpl(logonHandle, jobObjName, cmdLine); + +done: + if( profileIsLoaded ) { + UnloadProfileForLogon( logonHandle, &pi ); + profileIsLoaded = FALSE; + } + if( logonHandle != NULL ) { + CloseHandle(logonHandle); + } + + if (INVALID_HANDLE_VALUE != lsaHandle) { + UnregisterWithLsa(lsaHandle); + } + + return err; +} + + +//---------------------------------------------------------------------------- +// Function: IsTaskAlive +// +// Description: +// Checks if a task is alive via a jobobject. Outputs the +// appropriate information to stdout on success, or stderr on failure. +// +// Returns: +// ERROR_SUCCESS: On success +// GetLastError: otherwise +DWORD IsTaskAlive(const WCHAR* jobObjName, int* isAlive, int* procsInJob) +{ + PJOBOBJECT_BASIC_PROCESS_ID_LIST procList; + HANDLE jobObject = NULL; + int numProcs = 100; + + *isAlive = FALSE; + + jobObject = OpenJobObject(JOB_OBJECT_QUERY, FALSE, jobObjName); + + if(jobObject == NULL) + { + DWORD err = GetLastError(); + if(err == ERROR_FILE_NOT_FOUND) + { + // job object does not exist. assume its not alive + return ERROR_SUCCESS; + } + return err; + } + + procList = (PJOBOBJECT_BASIC_PROCESS_ID_LIST) LocalAlloc(LPTR, sizeof (JOBOBJECT_BASIC_PROCESS_ID_LIST) + numProcs*32); + if (!procList) + { + DWORD err = GetLastError(); + CloseHandle(jobObject); + return err; + } + if(QueryInformationJobObject(jobObject, JobObjectBasicProcessIdList, procList, sizeof(JOBOBJECT_BASIC_PROCESS_ID_LIST)+numProcs*32, NULL) == 0) + { + DWORD err = GetLastError(); + if(err != ERROR_MORE_DATA) + { + CloseHandle(jobObject); + LocalFree(procList); + return err; + } + } + + if(procList->NumberOfAssignedProcesses > 0) + { + *isAlive = TRUE; + *procsInJob = procList->NumberOfAssignedProcesses; + } + + LocalFree(procList); + + return ERROR_SUCCESS; +} + +//---------------------------------------------------------------------------- +// Function: KillTask +// +// Description: +// Kills a task via a jobobject. Outputs the +// appropriate information to stdout on success, or stderr on failure. +// +// Returns: +// ERROR_SUCCESS: On success +// GetLastError: otherwise +DWORD KillTask(PCWSTR jobObjName) +{ + HANDLE jobObject = OpenJobObject(JOB_OBJECT_TERMINATE, FALSE, jobObjName); + if(jobObject == NULL) + { + DWORD err = GetLastError(); + if(err == ERROR_FILE_NOT_FOUND) + { + // job object does not exist. assume its not alive + return ERROR_SUCCESS; + } + return err; + } + + if(TerminateJobObject(jobObject, KILLED_PROCESS_EXIT_CODE) == 0) + { + return GetLastError(); + } + CloseHandle(jobObject); + + return ERROR_SUCCESS; +} + +//---------------------------------------------------------------------------- +// Function: PrintTaskProcessList +// +// Description: +// Prints resource usage of all processes in the task jobobject +// +// Returns: +// ERROR_SUCCESS: On success +// GetLastError: otherwise +DWORD PrintTaskProcessList(const WCHAR* jobObjName) +{ + DWORD i; + PJOBOBJECT_BASIC_PROCESS_ID_LIST procList; + int numProcs = 100; + HANDLE jobObject = OpenJobObject(JOB_OBJECT_QUERY, FALSE, jobObjName); + if(jobObject == NULL) + { + DWORD err = GetLastError(); + return err; + } + + procList = (PJOBOBJECT_BASIC_PROCESS_ID_LIST) LocalAlloc(LPTR, sizeof (JOBOBJECT_BASIC_PROCESS_ID_LIST) + numProcs*32); + if (!procList) + { + DWORD err = GetLastError(); + CloseHandle(jobObject); + return err; + } + while(QueryInformationJobObject(jobObject, JobObjectBasicProcessIdList, procList, sizeof(JOBOBJECT_BASIC_PROCESS_ID_LIST)+numProcs*32, NULL) == 0) + { + DWORD err = GetLastError(); + if(err != ERROR_MORE_DATA) + { + CloseHandle(jobObject); + LocalFree(procList); + return err; + } + numProcs = procList->NumberOfAssignedProcesses; + LocalFree(procList); + procList = (PJOBOBJECT_BASIC_PROCESS_ID_LIST) LocalAlloc(LPTR, sizeof (JOBOBJECT_BASIC_PROCESS_ID_LIST) + numProcs*32); + if (procList == NULL) + { + err = GetLastError(); + CloseHandle(jobObject); + return err; + } + } + + for(i=0; iNumberOfProcessIdsInList; ++i) + { + HANDLE hProcess = OpenProcess( PROCESS_QUERY_INFORMATION, FALSE, (DWORD)procList->ProcessIdList[i] ); + if( hProcess != NULL ) + { + PROCESS_MEMORY_COUNTERS_EX pmc; + if ( GetProcessMemoryInfo( hProcess, (PPROCESS_MEMORY_COUNTERS)&pmc, sizeof(pmc)) ) + { + FILETIME create, exit, kernel, user; + if( GetProcessTimes( hProcess, &create, &exit, &kernel, &user) ) + { + ULARGE_INTEGER kernelTime, userTime; + ULONGLONG cpuTimeMs; + kernelTime.HighPart = kernel.dwHighDateTime; + kernelTime.LowPart = kernel.dwLowDateTime; + userTime.HighPart = user.dwHighDateTime; + userTime.LowPart = user.dwLowDateTime; + cpuTimeMs = (kernelTime.QuadPart+userTime.QuadPart)/10000; + fwprintf_s(stdout, L"%Iu,%Iu,%Iu,%I64u\n", procList->ProcessIdList[i], pmc.PrivateUsage, pmc.WorkingSetSize, cpuTimeMs); + } + } + CloseHandle( hProcess ); + } + } + + LocalFree(procList); + CloseHandle(jobObject); + + return ERROR_SUCCESS; +} + +//---------------------------------------------------------------------------- +// Function: Task +// +// Description: +// Manages a task via a jobobject (create/isAlive/kill). Outputs the +// appropriate information to stdout on success, or stderr on failure. +// +// Returns: +// ERROR_SUCCESS: On success +// Error code otherwise: otherwise +int Task(__in int argc, __in_ecount(argc) wchar_t *argv[]) +{ + DWORD dwErrorCode = ERROR_SUCCESS; + TaskCommandOption command = TaskInvalid; + wchar_t* cmdLine = NULL; + wchar_t buffer[16*1024] = L""; // 32K max command line + size_t charCountBufferLeft = sizeof (buffer)/sizeof(wchar_t); + int crtArgIndex = 0; + size_t argLen = 0; + size_t wscatErr = 0; + wchar_t* insertHere = NULL; + + enum { + ARGC_JOBOBJECTNAME = 2, + ARGC_USERNAME, + ARGC_PIDFILE, + ARGC_COMMAND, + ARGC_COMMAND_ARGS + }; + + if (!ParseCommandLine(argc, argv, &command)) { + dwErrorCode = ERROR_INVALID_COMMAND_LINE; + + fwprintf(stderr, L"Incorrect command line arguments.\n\n"); + TaskUsage(); + goto TaskExit; + } + + if (command == TaskCreate) + { + // Create the task jobobject + // + dwErrorCode = CreateTask(argv[2], argv[3]); + if (dwErrorCode != ERROR_SUCCESS) + { + ReportErrorCode(L"CreateTask", dwErrorCode); + goto TaskExit; + } + } else if (command == TaskCreateAsUser) + { + // Create the task jobobject as a domain user + // createAsUser accepts an open list of arguments. All arguments after the command are + // to be passed as argumrnts to the command itself.Here we're concatenating all + // arguments after the command into a single arg entry. + // + cmdLine = argv[ARGC_COMMAND]; + if (argc > ARGC_COMMAND_ARGS) { + crtArgIndex = ARGC_COMMAND; + insertHere = buffer; + while (crtArgIndex < argc) { + argLen = wcslen(argv[crtArgIndex]); + wscatErr = wcscat_s(insertHere, charCountBufferLeft, argv[crtArgIndex]); + switch (wscatErr) { + case 0: + // 0 means success; + break; + case EINVAL: + dwErrorCode = ERROR_INVALID_PARAMETER; + goto TaskExit; + case ERANGE: + dwErrorCode = ERROR_INSUFFICIENT_BUFFER; + goto TaskExit; + default: + // This case is not MSDN documented. + dwErrorCode = ERROR_GEN_FAILURE; + goto TaskExit; + } + insertHere += argLen; + charCountBufferLeft -= argLen; + insertHere[0] = L' '; + insertHere += 1; + charCountBufferLeft -= 1; + insertHere[0] = 0; + ++crtArgIndex; + } + cmdLine = buffer; + } + + dwErrorCode = CreateTaskAsUser( + argv[ARGC_JOBOBJECTNAME], argv[ARGC_USERNAME], argv[ARGC_PIDFILE], cmdLine); + if (dwErrorCode != ERROR_SUCCESS) + { + ReportErrorCode(L"CreateTaskAsUser", dwErrorCode); + goto TaskExit; + } + } else if (command == TaskIsAlive) + { + // Check if task jobobject + // + int isAlive; + int numProcs; + dwErrorCode = IsTaskAlive(argv[2], &isAlive, &numProcs); + if (dwErrorCode != ERROR_SUCCESS) + { + ReportErrorCode(L"IsTaskAlive", dwErrorCode); + goto TaskExit; + } + + // Output the result + if(isAlive == TRUE) + { + fwprintf(stdout, L"IsAlive,%d\n", numProcs); + } + else + { + dwErrorCode = ERROR_TASK_NOT_ALIVE; + ReportErrorCode(L"IsTaskAlive returned false", dwErrorCode); + goto TaskExit; + } + } else if (command == TaskKill) + { + // Check if task jobobject + // + dwErrorCode = KillTask(argv[2]); + if (dwErrorCode != ERROR_SUCCESS) + { + ReportErrorCode(L"KillTask", dwErrorCode); + goto TaskExit; + } + } else if (command == TaskProcessList) + { + // Check if task jobobject + // + dwErrorCode = PrintTaskProcessList(argv[2]); + if (dwErrorCode != ERROR_SUCCESS) + { + ReportErrorCode(L"PrintTaskProcessList", dwErrorCode); + goto TaskExit; + } + } else + { + // Should not happen + // + assert(FALSE); + } + +TaskExit: + return dwErrorCode; +} + +void TaskUsage() +{ + // Hadoop code checks for this string to determine if + // jobobject's are being used. + // ProcessTree.isSetsidSupported() + fwprintf(stdout, L"\ + Usage: task create [TASKNAME] [COMMAND_LINE] |\n\ + task createAsUser [TASKNAME] [USERNAME] [PIDFILE] [COMMAND_LINE] |\n\ + task isAlive [TASKNAME] |\n\ + task kill [TASKNAME]\n\ + task processList [TASKNAME]\n\ + Creates a new task jobobject with taskname\n\ + Creates a new task jobobject with taskname as the user provided\n\ + Checks if task jobobject is alive\n\ + Kills task jobobject\n\ + Prints to stdout a list of processes in the task\n\ + along with their resource usage. One process per line\n\ + and comma separated info per process\n\ + ProcessId,VirtualMemoryCommitted(bytes),\n\ + WorkingSetSize(bytes),CpuTime(Millisec,Kernel+User)\n"); +} diff --git hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/util/TestWinUtils.java hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/util/TestWinUtils.java index 588b217..53d34c5 100644 --- hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/util/TestWinUtils.java +++ hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/util/TestWinUtils.java @@ -20,10 +20,12 @@ import static org.junit.Assert.*; import static org.junit.Assume.assumeTrue; +import static org.junit.matchers.JUnitMatchers.containsString; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; +import java.io.FileWriter; import java.io.IOException; import org.apache.commons.io.FileUtils; @@ -33,7 +35,7 @@ import org.junit.After; import org.junit.Before; import org.junit.Test; -import static org.junit.Assume.*; + import static org.hamcrest.CoreMatchers.*; /** @@ -521,4 +523,25 @@ public void testReadLink() throws IOException { assertThat(ece.getExitCode(), is(1)); } } + + @Test(timeout=10000) + public void testTaskCreate() throws IOException { + File batch = new File(TEST_DIR, "testTaskCreate.cmd"); + File proof = new File(TEST_DIR, "testTaskCreate.out"); + FileWriter fw = new FileWriter(batch); + String testNumber = String.format("%f", Math.random()); + fw.write(String.format("echo %s > \"%s\"", testNumber, proof.getAbsolutePath())); + fw.close(); + + assertFalse(proof.exists()); + + Shell.execCommand(Shell.WINUTILS, "task", "create", "testTaskCreate" + testNumber, + batch.getAbsolutePath()); + + assertTrue(proof.exists()); + + String outNumber = FileUtils.readFileToString(proof); + + assertThat(outNumber, containsString(testNumber)); + } }