diff --git a/pkg/cim/smb.go b/pkg/cim/smb.go new file mode 100644 index 00000000..5868d456 --- /dev/null +++ b/pkg/cim/smb.go @@ -0,0 +1,53 @@ +//go:build windows +// +build windows + +package cim + +import ( + "github.com/microsoft/wmi/pkg/base/query" + cim "github.com/microsoft/wmi/pkg/wmiinstance" +) + +// Refer to https://learn.microsoft.com/en-us/previous-versions/windows/desktop/smb/msft-smbmapping +const ( + SmbMappingStatusOK int32 = iota + SmbMappingStatusPaused + SmbMappingStatusDisconnected + SmbMappingStatusNetworkError + SmbMappingStatusConnecting + SmbMappingStatusReconnecting + SmbMappingStatusUnavailable +) + +// QuerySmbGlobalMappingByRemotePath retrieves the SMB global mapping from its remote path. +// +// The equivalent WMI query is: +// +// SELECT [selectors] FROM MSFT_SmbGlobalMapping +// +// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping +// for the WMI class definition. +func QuerySmbGlobalMappingByRemotePath(remotePath string) (*cim.WmiInstance, error) { + smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", remotePath) + instances, err := QueryInstances(WMINamespaceSmb, smbQuery) + if err != nil { + return nil, err + } + + return instances[0], err +} + +// RemoveSmbGlobalMappingByRemotePath removes a SMB global mapping matching to the remote path. +// +// Refer to https://pkg.go.dev/github.com/microsoft/wmi/server2019/root/microsoft/windows/smb#MSFT_SmbGlobalMapping +// for the WMI class definition. +func RemoveSmbGlobalMappingByRemotePath(remotePath string) error { + smbQuery := query.NewWmiQuery("MSFT_SmbGlobalMapping", "RemotePath", remotePath) + instances, err := QueryInstances(WMINamespaceSmb, smbQuery) + if err != nil { + return err + } + + _, err = instances[0].InvokeMethod("Remove", true) + return err +} diff --git a/pkg/cim/wmi.go b/pkg/cim/wmi.go index 3e0375ba..81e17701 100644 --- a/pkg/cim/wmi.go +++ b/pkg/cim/wmi.go @@ -18,6 +18,7 @@ import ( const ( WMINamespaceRoot = "Root\\CimV2" WMINamespaceStorage = "Root\\Microsoft\\Windows\\Storage" + WMINamespaceSmb = "Root\\Microsoft\\Windows\\Smb" ) type InstanceHandler func(instance *cim.WmiInstance) (bool, error) diff --git a/pkg/os/smb/api.go b/pkg/os/smb/api.go index 910eb7e0..20b9544e 100644 --- a/pkg/os/smb/api.go +++ b/pkg/os/smb/api.go @@ -3,8 +3,15 @@ package smb import ( "fmt" "strings" + "syscall" + "github.com/kubernetes-csi/csi-proxy/pkg/cim" "github.com/kubernetes-csi/csi-proxy/pkg/utils" + "golang.org/x/sys/windows" +) + +const ( + credentialDelimiter = ":" ) type API interface { @@ -26,18 +33,52 @@ func New(requirePrivacy bool) *SmbAPI { } } +func remotePathForQuery(remotePath string) string { + return strings.ReplaceAll(remotePath, "\\", "\\\\") +} + +func escapeUserName(userName string) string { + // refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L169-L170 + escaped := strings.ReplaceAll(userName, "\\", "\\\\") + escaped = strings.ReplaceAll(escaped, credentialDelimiter, "\\"+credentialDelimiter) + return escaped +} + +func createSymlink(link, target string, isDir bool) error { + linkPtr, err := syscall.UTF16PtrFromString(link) + if err != nil { + return err + } + targetPtr, err := syscall.UTF16PtrFromString(target) + if err != nil { + return err + } + + var flags uint32 + if isDir { + flags = windows.SYMBOLIC_LINK_FLAG_DIRECTORY + } + + err = windows.CreateSymbolicLink( + linkPtr, + targetPtr, + flags, + ) + return err +} + func (*SmbAPI) IsSmbMapped(remotePath string) (bool, error) { - cmdLine := `$(Get-SmbGlobalMapping -RemotePath $Env:smbremotepath -ErrorAction Stop).Status ` - cmdEnv := fmt.Sprintf("smbremotepath=%s", remotePath) - out, err := utils.RunPowershellCmd(cmdLine, cmdEnv) + inst, err := cim.QuerySmbGlobalMappingByRemotePath(remotePathForQuery(remotePath)) if err != nil { - return false, fmt.Errorf("error checking smb mapping. cmd %s, output: %s, err: %v", remotePath, string(out), err) + return false, cim.IgnoreNotFound(err) } - if len(out) == 0 || !strings.EqualFold(strings.TrimSpace(string(out)), "OK") { - return false, nil + status, err := inst.GetProperty("Status") + if err != nil { + return false, err } - return true, nil + + return status.(int32) == cim.SmbMappingStatusOK, nil } // NewSmbLink - creates a directory symbolic link to the remote share. @@ -48,42 +89,46 @@ func (*SmbAPI) IsSmbMapped(remotePath string) (bool, error) { // alpha to merge the paths. // TODO (for beta release): Merge the link paths - os.Symlink and Powershell link path. func (*SmbAPI) NewSmbLink(remotePath, localPath string) error { - if !strings.HasSuffix(remotePath, "\\") { // Golang has issues resolving paths mapped to file shares if they do not end in a trailing \ // so add one if needed. remotePath = remotePath + "\\" } + longRemotePath := utils.EnsureLongPath(remotePath) + longLocalPath := utils.EnsureLongPath(localPath) - cmdLine := `New-Item -ItemType SymbolicLink $Env:smblocalPath -Target $Env:smbremotepath` - output, err := utils.RunPowershellCmd(cmdLine, fmt.Sprintf("smbremotepath=%s", remotePath), fmt.Sprintf("smblocalpath=%s", localPath)) + err := createSymlink(longLocalPath, longRemotePath, true) if err != nil { - return fmt.Errorf("error linking %s to %s. output: %s, err: %v", remotePath, localPath, string(output), err) + return fmt.Errorf("error linking %s to %s. err: %v", remotePath, localPath, err) } return nil } func (api *SmbAPI) NewSmbGlobalMapping(remotePath, username, password string) error { - // use PowerShell Environment Variables to store user input string to prevent command line injection - // https://docs.microsoft.com/en-us/powershell/module/microsoft.powershell.core/about/about_environment_variables?view=powershell-5.1 - cmdLine := fmt.Sprintf(`$PWord = ConvertTo-SecureString -String $Env:smbpassword -AsPlainText -Force`+ - `;$Credential = New-Object -TypeName System.Management.Automation.PSCredential -ArgumentList $Env:smbuser, $PWord`+ - `;New-SmbGlobalMapping -RemotePath $Env:smbremotepath -Credential $Credential -RequirePrivacy $%t`, api.RequirePrivacy) - - if output, err := utils.RunPowershellCmd(cmdLine, - fmt.Sprintf("smbuser=%s", username), - fmt.Sprintf("smbpassword=%s", password), - fmt.Sprintf("smbremotepath=%s", remotePath)); err != nil { - return fmt.Errorf("NewSmbGlobalMapping failed. output: %q, err: %v", string(output), err) + params := map[string]interface{}{ + "RemotePath": remotePath, + "RequirePrivacy": api.RequirePrivacy, + } + if username != "" { + // refer to https://github.com/PowerShell/PowerShell/blob/9303de597da55963a6e26a8fe164d0b256ca3d4d/src/Microsoft.PowerShell.Commands.Management/cimSupport/cmdletization/cim/cimConverter.cs#L166-L178 + // on how SMB credential is handled in PowerShell + params["Credential"] = escapeUserName(username) + credentialDelimiter + password } + + result, _, err := cim.InvokeCimMethod(cim.WMINamespaceSmb, "MSFT_SmbGlobalMapping", "Create", params) + if err != nil { + return fmt.Errorf("NewSmbGlobalMapping failed. result: %d, err: %v", result, err) + } + return nil } func (*SmbAPI) RemoveSmbGlobalMapping(remotePath string) error { - cmd := `Remove-SmbGlobalMapping -RemotePath $Env:smbremotepath -Force` - if output, err := utils.RunPowershellCmd(cmd, fmt.Sprintf("smbremotepath=%s", remotePath)); err != nil { - return fmt.Errorf("UnmountSmbShare failed. output: %q, err: %v", string(output), err) + err := cim.RemoveSmbGlobalMappingByRemotePath(remotePathForQuery(remotePath)) + if err != nil { + return fmt.Errorf("error remove smb mapping '%s'. err: %v", remotePath, err) } + return nil } diff --git a/pkg/os/volume/api.go b/pkg/os/volume/api.go index 6394ab28..5bdf0e04 100644 --- a/pkg/os/volume/api.go +++ b/pkg/os/volume/api.go @@ -10,6 +10,7 @@ import ( "github.com/go-ole/go-ole" "github.com/kubernetes-csi/csi-proxy/pkg/cim" + "github.com/kubernetes-csi/csi-proxy/pkg/utils" wmierrors "github.com/microsoft/wmi/pkg/errors" "github.com/pkg/errors" "golang.org/x/sys/windows" @@ -57,8 +58,6 @@ var ( // PS C:\disks> (Get-Disk -Number 1 | Get-Partition | Get-Volume).UniqueId // \\?\Volume{452e318a-5cde-421e-9831-b9853c521012}\ VolumeRegexp = regexp.MustCompile(`Volume\{[\w-]*\}`) - // longPathPrefix is the prefix of Windows long path - longPathPrefix = "\\\\?\\" notMountedFolder = errors.New("not a mounted folder") ) @@ -337,7 +336,7 @@ func getTarget(mount string) (string, error) { if err != nil { return "", err } - targetPath := longPathPrefix + windows.UTF16PtrToString(&outPathBuffer[0]) + targetPath := utils.EnsureLongPath(windows.UTF16PtrToString(&outPathBuffer[0])) if !strings.HasSuffix(targetPath, "\\") { targetPath += "\\" } diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 87c50339..102675ac 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -3,11 +3,24 @@ package utils import ( "os" "os/exec" + "strings" "k8s.io/klog/v2" ) -const MaxPathLengthWindows = 260 +const ( + MaxPathLengthWindows = 260 + + // LongPathPrefix is the prefix of Windows long path + LongPathPrefix = `\\?\` +) + +func EnsureLongPath(path string) string { + if !strings.HasPrefix(path, LongPathPrefix) { + path = LongPathPrefix + path + } + return path +} func RunPowershellCmd(command string, envs ...string) ([]byte, error) { cmd := exec.Command("powershell", "-Mta", "-NoProfile", "-Command", command)