|
23 | 23 | from itertools import chain
|
24 | 24 | from os.path import dirname, isfile
|
25 | 25 | from pathlib import Path
|
26 |
| -from typing import Dict, List, Optional, Sequence, Tuple |
| 26 | +from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, Union |
27 | 27 |
|
28 |
| -from pkg_resources import parse_requirements |
| 28 | +from pkg_resources import parse_requirements, Requirement, yield_lines |
29 | 29 |
|
30 | 30 | REQUIREMENT_FILES = {
|
31 | 31 | "pytorch": (
|
|
49 | 49 | _PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__))
|
50 | 50 |
|
51 | 51 |
|
52 |
| -def _augment_requirement(ln: str, comment_char: str = "#", unfreeze: str = "all") -> str: |
53 |
| - """Adjust the upper version contrains. |
54 |
| -
|
55 |
| - Args: |
56 |
| - ln: raw line from requirement |
57 |
| - comment_char: charter marking comment |
58 |
| - unfreeze: Enum or "all"|"major"|"" |
59 |
| -
|
60 |
| - Returns: |
61 |
| - adjusted requirement |
62 |
| -
|
63 |
| - >>> _augment_requirement("arrow<=1.2.2,>=1.2.0 # anything", unfreeze="none") |
64 |
| - 'arrow<=1.2.2,>=1.2.0' |
65 |
| - >>> _augment_requirement("arrow<=1.2.2,>=1.2.0 # strict", unfreeze="none") |
66 |
| - 'arrow<=1.2.2,>=1.2.0 # strict' |
67 |
| - >>> _augment_requirement("arrow<=1.2.2,>=1.2.0 # my name", unfreeze="all") |
68 |
| - 'arrow>=1.2.0' |
69 |
| - >>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # strict", unfreeze="all") |
70 |
| - 'arrow>=1.2.0, <=1.2.2 # strict' |
71 |
| - >>> _augment_requirement("arrow", unfreeze="all") |
72 |
| - 'arrow' |
73 |
| - >>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # cool", unfreeze="major") |
74 |
| - 'arrow>=1.2.0, <2.0 # strict' |
75 |
| - >>> _augment_requirement("arrow>=1.2.0, <=1.2.2 # strict", unfreeze="major") |
76 |
| - 'arrow>=1.2.0, <=1.2.2 # strict' |
77 |
| - >>> _augment_requirement("arrow>=1.2.0", unfreeze="major") |
78 |
| - 'arrow>=1.2.0, <2.0 # strict' |
79 |
| - >>> _augment_requirement("arrow", unfreeze="major") |
80 |
| - 'arrow' |
| 52 | +class _RequirementWithComment(Requirement): |
| 53 | + strict_string = "# strict" |
| 54 | + |
| 55 | + def __init__(self, *args: Any, comment: str = "", pip_argument: Optional[str] = None, **kwargs: Any) -> None: |
| 56 | + super().__init__(*args, **kwargs) |
| 57 | + self.comment = comment |
| 58 | + assert pip_argument is None or pip_argument # sanity check that it's not an empty str |
| 59 | + self.pip_argument = pip_argument |
| 60 | + self.strict = self.strict_string in comment.lower() |
| 61 | + |
| 62 | + def adjust(self, unfreeze: str) -> str: |
| 63 | + """Remove version restrictions unless they are strict. |
| 64 | +
|
| 65 | + >>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# anything").adjust("none") |
| 66 | + 'arrow<=1.2.2,>=1.2.0' |
| 67 | + >>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# strict").adjust("none") |
| 68 | + 'arrow<=1.2.2,>=1.2.0 # strict' |
| 69 | + >>> _RequirementWithComment("arrow<=1.2.2,>=1.2.0", comment="# my name").adjust("all") |
| 70 | + 'arrow>=1.2.0' |
| 71 | + >>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# strict").adjust("all") |
| 72 | + 'arrow<=1.2.2,>=1.2.0 # strict' |
| 73 | + >>> _RequirementWithComment("arrow").adjust("all") |
| 74 | + 'arrow' |
| 75 | + >>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# cool").adjust("major") |
| 76 | + 'arrow<2.0,>=1.2.0' |
| 77 | + >>> _RequirementWithComment("arrow>=1.2.0, <=1.2.2", comment="# strict").adjust("major") |
| 78 | + 'arrow<=1.2.2,>=1.2.0 # strict' |
| 79 | + >>> _RequirementWithComment("arrow>=1.2.0").adjust("major") |
| 80 | + 'arrow>=1.2.0' |
| 81 | + >>> _RequirementWithComment("arrow").adjust("major") |
| 82 | + 'arrow' |
| 83 | + """ |
| 84 | + out = str(self) |
| 85 | + if self.strict: |
| 86 | + return f"{out} {self.strict_string}" |
| 87 | + if unfreeze == "major": |
| 88 | + for operator, version in self.specs: |
| 89 | + if operator in ("<", "<="): |
| 90 | + major = LooseVersion(version).version[0] |
| 91 | + # replace upper bound with major version increased by one |
| 92 | + return out.replace(f"{operator}{version}", f"<{major + 1}.0") |
| 93 | + elif unfreeze == "all": |
| 94 | + for operator, version in self.specs: |
| 95 | + if operator in ("<", "<="): |
| 96 | + # drop upper bound |
| 97 | + return out.replace(f"{operator}{version},", "") |
| 98 | + elif unfreeze != "none": |
| 99 | + raise ValueError(f"Unexpected unfreeze: {unfreeze!r} value.") |
| 100 | + return out |
| 101 | + |
| 102 | + |
| 103 | +def _parse_requirements(strs: Union[str, Iterable[str]]) -> Iterator[_RequirementWithComment]: |
| 104 | + """Adapted from `pkg_resources.parse_requirements` to include comments. |
| 105 | +
|
| 106 | + >>> txt = ['# ignored', '', 'this # is an', '--piparg', 'example', 'foo # strict', 'thing', '-r different/file.txt'] |
| 107 | + >>> [r.adjust('none') for r in _parse_requirements(txt)] |
| 108 | + ['this', 'example', 'foo # strict', 'thing'] |
| 109 | + >>> txt = '\\n'.join(txt) |
| 110 | + >>> [r.adjust('none') for r in _parse_requirements(txt)] |
| 111 | + ['this', 'example', 'foo # strict', 'thing'] |
81 | 112 | """
|
82 |
| - assert unfreeze in {"none", "major", "all"} |
83 |
| - # filer all comments |
84 |
| - if comment_char in ln: |
85 |
| - comment = ln[ln.index(comment_char) :] |
86 |
| - ln = ln[: ln.index(comment_char)] |
87 |
| - is_strict = "strict" in comment |
88 |
| - else: |
89 |
| - is_strict = False |
90 |
| - req = ln.strip() |
91 |
| - # skip directly installed dependencies |
92 |
| - if not req or any(c in req for c in ["http:", "https:", "@"]): |
93 |
| - return "" |
94 |
| - # extract the major version from all listed versions |
95 |
| - if unfreeze == "major": |
96 |
| - req_ = list(parse_requirements([req]))[0] |
97 |
| - vers = [LooseVersion(v) for s, v in req_.specs if s not in ("==", "~=")] |
98 |
| - ver_major = sorted(vers)[-1].version[0] if vers else None |
99 |
| - else: |
100 |
| - ver_major = None |
101 |
| - |
102 |
| - # remove version restrictions unless they are strict |
103 |
| - if unfreeze != "none" and "<" in req and not is_strict: |
104 |
| - req = re.sub(r",? *<=? *[\d\.\*]+,? *", "", req).strip() |
105 |
| - if ver_major is not None and not is_strict: |
106 |
| - # add , only if there are already some versions |
107 |
| - req += f"{',' if any(c in req for c in '<=>') else ''} <{int(ver_major) + 1}.0" |
108 |
| - |
109 |
| - # adding strict back to the comment |
110 |
| - if is_strict or ver_major is not None: |
111 |
| - req += " # strict" |
112 |
| - |
113 |
| - return req |
114 |
| - |
115 |
| - |
116 |
| -def load_requirements( |
117 |
| - path_dir: str, file_name: str = "base.txt", comment_char: str = "#", unfreeze: str = "all" |
118 |
| -) -> List[str]: |
| 113 | + lines = yield_lines(strs) |
| 114 | + pip_argument = None |
| 115 | + for line in lines: |
| 116 | + # Drop comments -- a hash without a space may be in a URL. |
| 117 | + if " #" in line: |
| 118 | + comment_pos = line.find(" #") |
| 119 | + line, comment = line[:comment_pos], line[comment_pos:] |
| 120 | + else: |
| 121 | + comment = "" |
| 122 | + # If there is a line continuation, drop it, and append the next line. |
| 123 | + if line.endswith("\\"): |
| 124 | + line = line[:-2].strip() |
| 125 | + try: |
| 126 | + line += next(lines) |
| 127 | + except StopIteration: |
| 128 | + return |
| 129 | + # If there's a pip argument, save it |
| 130 | + if line.startswith("--"): |
| 131 | + pip_argument = line |
| 132 | + continue |
| 133 | + if line.startswith("-r "): |
| 134 | + # linked requirement files are unsupported |
| 135 | + continue |
| 136 | + yield _RequirementWithComment(line, comment=comment, pip_argument=pip_argument) |
| 137 | + pip_argument = None |
| 138 | + |
| 139 | + |
| 140 | +def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: str = "all") -> List[str]: |
119 | 141 | """Loading requirements from a file.
|
120 | 142 |
|
121 | 143 | >>> path_req = os.path.join(_PROJECT_ROOT, "requirements")
|
122 | 144 | >>> load_requirements(path_req, "docs.txt", unfreeze="major") # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
|
123 |
| - ['sphinx>=4.0, <6.0 # strict', ...] |
| 145 | + ['sphinx<6.0,>=4.0', ...] |
124 | 146 | """
|
125 | 147 | assert unfreeze in {"none", "major", "all"}
|
126 |
| - with open(os.path.join(path_dir, file_name)) as file: |
127 |
| - lines = [ln.strip() for ln in file.readlines()] |
128 |
| - reqs = [_augment_requirement(ln, comment_char=comment_char, unfreeze=unfreeze) for ln in lines] |
129 |
| - # filter empty lines and containing @ which means redirect to some git/http |
130 |
| - reqs = [str(req) for req in reqs if req and not any(c in req for c in ["@", "http:", "https:"])] |
131 |
| - return reqs |
| 148 | + path = Path(path_dir) / file_name |
| 149 | + assert path.exists(), (path_dir, file_name, path) |
| 150 | + text = path.read_text() |
| 151 | + return [req.adjust(unfreeze) for req in _parse_requirements(text)] |
132 | 152 |
|
133 | 153 |
|
134 | 154 | def load_readme_description(path_dir: str, homepage: str, version: str) -> str:
|
@@ -213,14 +233,13 @@ def _load_aggregate_requirements(req_dir: str = "requirements", freeze_requireme
|
213 | 233 | >>> _load_aggregate_requirements(os.path.join(_PROJECT_ROOT, "requirements"))
|
214 | 234 | """
|
215 | 235 | requires = [
|
216 |
| - # TODO: consider passing unfreeze as string instead |
217 |
| - load_requirements(d, file_name="base.txt", unfreeze="none" if freeze_requirements else "major") |
| 236 | + load_requirements(d, unfreeze="none" if freeze_requirements else "major") |
218 | 237 | for d in glob.glob(os.path.join(req_dir, "*"))
|
219 | 238 | # skip empty folder as git artefacts, and resolving Will's special issue
|
220 | 239 | if os.path.isdir(d) and len(glob.glob(os.path.join(d, "*"))) > 0 and "__pycache__" not in d
|
221 | 240 | ]
|
222 | 241 | if not requires:
|
223 |
| - return None |
| 242 | + return |
224 | 243 | # TODO: add some smarter version aggregation per each package
|
225 | 244 | requires = sorted(set(chain(*requires)))
|
226 | 245 | with open(os.path.join(req_dir, "base.txt"), "w") as fp:
|
|
0 commit comments