|
18 | 18 |
|
19 | 19 | """
|
20 | 20 | import glob
|
| 21 | +import fnmatch |
21 | 22 | import string
|
22 | 23 | import os
|
23 | 24 | import os.path as op
|
24 | 25 | import shutil
|
| 26 | +import subprocess |
25 | 27 | import re
|
26 | 28 | import tempfile
|
27 | 29 | from warnings import warn
|
|
34 | 36 | except:
|
35 | 37 | pass
|
36 | 38 |
|
| 39 | +try: |
| 40 | + import paramiko |
| 41 | +except: |
| 42 | + pass |
| 43 | + |
37 | 44 | from nipype.interfaces.base import (TraitedSpec, traits, File, Directory,
|
38 | 45 | BaseInterface, InputMultiPath, isdefined,
|
39 | 46 | OutputMultiPath, DynamicTraitedSpec,
|
@@ -750,7 +757,7 @@ class DataFinder(IOBase):
|
750 | 757 | '013-ep2d_fid_T1_pre']
|
751 | 758 | >>> print result.outputs.basename # doctest: +SKIP
|
752 | 759 | ['acquisition',
|
753 |
| - 'acquisition', |
| 760 | + 'acquisition' |
754 | 761 | 'acquisition',
|
755 | 762 | 'acquisition']
|
756 | 763 |
|
@@ -1539,3 +1546,260 @@ def _list_outputs(self):
|
1539 | 1546 | conn.commit()
|
1540 | 1547 | c.close()
|
1541 | 1548 | return None
|
| 1549 | + |
| 1550 | +class SSHDataGrabberInputSpec(DataGrabberInputSpec): |
| 1551 | + hostname = traits.Str(mandatory=True, |
| 1552 | + desc='Server hostname.') |
| 1553 | + username = traits.Str(mandatory=False, |
| 1554 | + desc='Server username.') |
| 1555 | + password = traits.Password(mandatory=False, |
| 1556 | + desc='Server password.') |
| 1557 | + download_files = traits.Bool(True, usedefault=True, |
| 1558 | + desc='If false it will return the file names without downloading them') |
| 1559 | + base_directory = traits.Str(mandatory=True, |
| 1560 | + desc='Path to the base directory consisting of subject data.') |
| 1561 | + template_expression = traits.Enum(['fnmatch', 'regexp'], usedefault=True, |
| 1562 | + desc='Use either fnmatch or regexp to express templates') |
| 1563 | + ssh_log_to_file = traits.Str('', usedefault=True, |
| 1564 | + desc='If set SSH commands will be logged to the given file') |
| 1565 | + |
| 1566 | + |
| 1567 | +class SSHDataGrabber(DataGrabber): |
| 1568 | + """ Extension of DataGrabber module that downloads the file list and |
| 1569 | + optionally the files from a SSH server. The SSH operation must |
| 1570 | + not need user and password so an SSH agent must be active in |
| 1571 | + where this module is being run. |
| 1572 | +
|
| 1573 | +
|
| 1574 | + .. attention:: |
| 1575 | +
|
| 1576 | + Doesn't support directories currently |
| 1577 | +
|
| 1578 | + Examples |
| 1579 | + -------- |
| 1580 | +
|
| 1581 | + >>> from nipype.interfaces.io import SSHDataGrabber |
| 1582 | + >>> dg = SSHDataGrabber() |
| 1583 | + >>> dg.inputs.hostname = 'test.rebex.net' |
| 1584 | + >>> dg.inputs.user = 'demo' |
| 1585 | + >>> dg.inputs.password = 'password' |
| 1586 | + >>> dg.inputs.base_directory = 'pub/example' |
| 1587 | +
|
| 1588 | + Pick all files from the base directory |
| 1589 | +
|
| 1590 | + >>> dg.inputs.template = '*' |
| 1591 | +
|
| 1592 | + Pick all files starting with "s" and a number from current directory |
| 1593 | +
|
| 1594 | + >>> dg.inputs.template_expression = 'regexp' |
| 1595 | + >>> dg.inputs.template = 'pop[0-9].*' |
| 1596 | +
|
| 1597 | + Same thing but with dynamically created fields |
| 1598 | +
|
| 1599 | + >>> dg = SSHDataGrabber(infields=['arg1','arg2']) |
| 1600 | + >>> dg.inputs.hostname = 'test.rebex.net' |
| 1601 | + >>> dg.inputs.user = 'demo' |
| 1602 | + >>> dg.inputs.password = 'password' |
| 1603 | + >>> dg.inputs.base_directory = 'pub' |
| 1604 | + >>> dg.inputs.template = '%s/%s.txt' |
| 1605 | + >>> dg.inputs.arg1 = 'example' |
| 1606 | + >>> dg.inputs.arg2 = 'foo' |
| 1607 | +
|
| 1608 | + however this latter form can be used with iterables and iterfield in a |
| 1609 | + pipeline. |
| 1610 | +
|
| 1611 | + Dynamically created, user-defined input and output fields |
| 1612 | +
|
| 1613 | + >>> dg = SSHDataGrabber(infields=['sid'], outfields=['func','struct','ref']) |
| 1614 | + >>> dg.inputs.hostname = 'myhost.com' |
| 1615 | + >>> dg.inputs.base_directory = '/main_folder/my_remote_dir' |
| 1616 | + >>> dg.inputs.template_args['func'] = [['sid',['f3','f5']]] |
| 1617 | + >>> dg.inputs.template_args['struct'] = [['sid',['struct']]] |
| 1618 | + >>> dg.inputs.template_args['ref'] = [['sid','ref']] |
| 1619 | + >>> dg.inputs.sid = 's1' |
| 1620 | +
|
| 1621 | + Change the template only for output field struct. The rest use the |
| 1622 | + general template |
| 1623 | +
|
| 1624 | + >>> dg.inputs.field_template = dict(struct='%s/struct.nii') |
| 1625 | + >>> dg.inputs.template_args['struct'] = [['sid']] |
| 1626 | +
|
| 1627 | + """ |
| 1628 | + input_spec = SSHDataGrabberInputSpec |
| 1629 | + output_spec = DynamicTraitedSpec |
| 1630 | + _always_run = False |
| 1631 | + |
| 1632 | + def __init__(self, infields=None, outfields=None, **kwargs): |
| 1633 | + """ |
| 1634 | + Parameters |
| 1635 | + ---------- |
| 1636 | + infields : list of str |
| 1637 | + Indicates the input fields to be dynamically created |
| 1638 | +
|
| 1639 | + outfields: list of str |
| 1640 | + Indicates output fields to be dynamically created |
| 1641 | +
|
| 1642 | + See class examples for usage |
| 1643 | +
|
| 1644 | + """ |
| 1645 | + try: |
| 1646 | + paramiko |
| 1647 | + except NameError: |
| 1648 | + warn( |
| 1649 | + "The library parmiko needs to be installed" |
| 1650 | + " for this module to run." |
| 1651 | + ) |
| 1652 | + if not outfields: |
| 1653 | + outfields = ['outfiles'] |
| 1654 | + kwargs = kwargs.copy() |
| 1655 | + kwargs['infields'] = infields |
| 1656 | + kwargs['outfields'] = outfields |
| 1657 | + super(SSHDataGrabber, self).__init__(**kwargs) |
| 1658 | + if ( |
| 1659 | + None in (self.inputs.username, self.inputs.password) |
| 1660 | + ): |
| 1661 | + raise ValueError( |
| 1662 | + "either both username and password " |
| 1663 | + "are provided or none of them" |
| 1664 | + ) |
| 1665 | + |
| 1666 | + if ( |
| 1667 | + self.inputs.template_expression == 'regexp' and |
| 1668 | + self.inputs.template[-1] != '$' |
| 1669 | + ): |
| 1670 | + self.inputs.template += '$' |
| 1671 | + |
| 1672 | + |
| 1673 | + def _list_outputs(self): |
| 1674 | + try: |
| 1675 | + paramiko |
| 1676 | + except NameError: |
| 1677 | + raise ImportError( |
| 1678 | + "The library parmiko needs to be installed" |
| 1679 | + " for this module to run." |
| 1680 | + ) |
| 1681 | + |
| 1682 | + if len(self.inputs.ssh_log_to_file) > 0: |
| 1683 | + paramiko.util.log_to_file(self.inputs.ssh_log_to_file) |
| 1684 | + # infields are mandatory, however I could not figure out how to set 'mandatory' flag dynamically |
| 1685 | + # hence manual check |
| 1686 | + if self._infields: |
| 1687 | + for key in self._infields: |
| 1688 | + value = getattr(self.inputs, key) |
| 1689 | + if not isdefined(value): |
| 1690 | + msg = "%s requires a value for input '%s' because it was listed in 'infields'" % \ |
| 1691 | + (self.__class__.__name__, key) |
| 1692 | + raise ValueError(msg) |
| 1693 | + |
| 1694 | + outputs = {} |
| 1695 | + for key, args in self.inputs.template_args.items(): |
| 1696 | + outputs[key] = [] |
| 1697 | + template = self.inputs.template |
| 1698 | + if hasattr(self.inputs, 'field_template') and \ |
| 1699 | + isdefined(self.inputs.field_template) and \ |
| 1700 | + key in self.inputs.field_template: |
| 1701 | + template = self.inputs.field_template[key] |
| 1702 | + if not args: |
| 1703 | + client = self._get_ssh_client() |
| 1704 | + sftp = client.open_sftp() |
| 1705 | + sftp.chdir(self.inputs.base_directory) |
| 1706 | + filelist = sftp.listdir() |
| 1707 | + if self.inputs.template_expression == 'fnmatch': |
| 1708 | + filelist = fnmatch.filter(filelist, template) |
| 1709 | + elif self.inputs.template_expression == 'regexp': |
| 1710 | + regexp = re.compile(template) |
| 1711 | + filelist = filter(regexp.match, filelist) |
| 1712 | + else: |
| 1713 | + raise ValueError('template_expression value invalid') |
| 1714 | + if len(filelist) == 0: |
| 1715 | + msg = 'Output key: %s Template: %s returned no files' % ( |
| 1716 | + key, template) |
| 1717 | + if self.inputs.raise_on_empty: |
| 1718 | + raise IOError(msg) |
| 1719 | + else: |
| 1720 | + warn(msg) |
| 1721 | + else: |
| 1722 | + if self.inputs.sort_filelist: |
| 1723 | + filelist = human_order_sorted(filelist) |
| 1724 | + outputs[key] = list_to_filename(filelist) |
| 1725 | + if self.inputs.download_files: |
| 1726 | + for f in filelist: |
| 1727 | + sftp.get(f, f) |
| 1728 | + for argnum, arglist in enumerate(args): |
| 1729 | + maxlen = 1 |
| 1730 | + for arg in arglist: |
| 1731 | + if isinstance(arg, str) and hasattr(self.inputs, arg): |
| 1732 | + arg = getattr(self.inputs, arg) |
| 1733 | + if isinstance(arg, list): |
| 1734 | + if (maxlen > 1) and (len(arg) != maxlen): |
| 1735 | + raise ValueError('incompatible number of arguments for %s' % key) |
| 1736 | + if len(arg) > maxlen: |
| 1737 | + maxlen = len(arg) |
| 1738 | + outfiles = [] |
| 1739 | + for i in range(maxlen): |
| 1740 | + argtuple = [] |
| 1741 | + for arg in arglist: |
| 1742 | + if isinstance(arg, str) and hasattr(self.inputs, arg): |
| 1743 | + arg = getattr(self.inputs, arg) |
| 1744 | + if isinstance(arg, list): |
| 1745 | + argtuple.append(arg[i]) |
| 1746 | + else: |
| 1747 | + argtuple.append(arg) |
| 1748 | + filledtemplate = template |
| 1749 | + if argtuple: |
| 1750 | + try: |
| 1751 | + filledtemplate = template % tuple(argtuple) |
| 1752 | + except TypeError as e: |
| 1753 | + raise TypeError(e.message + ": Template %s failed to convert with args %s" % (template, str(tuple(argtuple)))) |
| 1754 | + client = self._get_ssh_client() |
| 1755 | + sftp = client.open_sftp() |
| 1756 | + sftp.chdir(self.inputs.base_directory) |
| 1757 | + filledtemplate_dir = os.path.dirname(filledtemplate) |
| 1758 | + filledtemplate_base = os.path.basename(filledtemplate) |
| 1759 | + filelist = sftp.listdir(filledtemplate_dir) |
| 1760 | + if self.inputs.template_expression == 'fnmatch': |
| 1761 | + outfiles = fnmatch.filter(filelist, filledtemplate_base) |
| 1762 | + elif self.inputs.template_expression == 'regexp': |
| 1763 | + regexp = re.compile(filledtemplate_base) |
| 1764 | + outfiles = filter(regexp.match, filelist) |
| 1765 | + else: |
| 1766 | + raise ValueError('template_expression value invalid') |
| 1767 | + if len(outfiles) == 0: |
| 1768 | + msg = 'Output key: %s Template: %s returned no files' % (key, filledtemplate) |
| 1769 | + if self.inputs.raise_on_empty: |
| 1770 | + raise IOError(msg) |
| 1771 | + else: |
| 1772 | + warn(msg) |
| 1773 | + outputs[key].append(None) |
| 1774 | + else: |
| 1775 | + if self.inputs.sort_filelist: |
| 1776 | + outfiles = human_order_sorted(outfiles) |
| 1777 | + outputs[key].append(list_to_filename(outfiles)) |
| 1778 | + if self.inputs.download_files: |
| 1779 | + for f in outfiles: |
| 1780 | + sftp.get(os.path.join(filledtemplate_dir, f), f) |
| 1781 | + if any([val is None for val in outputs[key]]): |
| 1782 | + outputs[key] = [] |
| 1783 | + if len(outputs[key]) == 0: |
| 1784 | + outputs[key] = None |
| 1785 | + elif len(outputs[key]) == 1: |
| 1786 | + outputs[key] = outputs[key][0] |
| 1787 | + return outputs |
| 1788 | + |
| 1789 | + def _get_ssh_client(self): |
| 1790 | + config = paramiko.SSHConfig() |
| 1791 | + config.parse(open(os.path.expanduser('~/.ssh/config'))) |
| 1792 | + host = config.lookup(self.inputs.hostname) |
| 1793 | + if 'proxycommand' in host: |
| 1794 | + proxy = paramiko.ProxyCommand( |
| 1795 | + subprocess.check_output( |
| 1796 | + [os.environ['SHELL'], '-c', 'echo %s' % host['proxycommand']] |
| 1797 | + ).strip() |
| 1798 | + ) |
| 1799 | + else: |
| 1800 | + proxy = None |
| 1801 | + client = paramiko.SSHClient() |
| 1802 | + client.load_system_host_keys() |
| 1803 | + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) |
| 1804 | + client.connect(host['hostname'], username=host['user'], sock=proxy) |
| 1805 | + return client |
0 commit comments