Skip to content

Commit ea10d3d

Browse files
authored
Command line args bounds checking (#424)
* command line args bounds checking * unknown and invalid param exit codes 0 -> 1
1 parent a18c192 commit ea10d3d

File tree

1 file changed

+84
-17
lines changed

1 file changed

+84
-17
lines changed

utils.cpp

Lines changed: 84 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,41 +26,95 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
2626
params.n_threads = std::max(1, (int32_t) std::thread::hardware_concurrency());
2727
}
2828

29+
bool invalid_param = false;
30+
std::string arg;
2931
for (int i = 1; i < argc; i++) {
30-
std::string arg = argv[i];
32+
arg = argv[i];
3133

3234
if (arg == "-s" || arg == "--seed") {
33-
params.seed = std::stoi(argv[++i]);
35+
if (++i >= argc) {
36+
invalid_param = true;
37+
break;
38+
}
39+
params.seed = std::stoi(argv[i]);
3440
} else if (arg == "-t" || arg == "--threads") {
35-
params.n_threads = std::stoi(argv[++i]);
41+
if (++i >= argc) {
42+
invalid_param = true;
43+
break;
44+
}
45+
params.n_threads = std::stoi(argv[i]);
3646
} else if (arg == "-p" || arg == "--prompt") {
37-
params.prompt = argv[++i];
47+
if (++i >= argc) {
48+
invalid_param = true;
49+
break;
50+
}
51+
params.prompt = argv[i];
3852
} else if (arg == "-f" || arg == "--file") {
39-
std::ifstream file(argv[++i]);
53+
if (++i >= argc) {
54+
invalid_param = true;
55+
break;
56+
}
57+
std::ifstream file(argv[i]);
4058
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
4159
if (params.prompt.back() == '\n') {
4260
params.prompt.pop_back();
4361
}
4462
} else if (arg == "-n" || arg == "--n_predict") {
45-
params.n_predict = std::stoi(argv[++i]);
63+
if (++i >= argc) {
64+
invalid_param = true;
65+
break;
66+
}
67+
params.n_predict = std::stoi(argv[i]);
4668
} else if (arg == "--top_k") {
47-
params.top_k = std::stoi(argv[++i]);
69+
if (++i >= argc) {
70+
invalid_param = true;
71+
break;
72+
}
73+
params.top_k = std::stoi(argv[i]);
4874
} else if (arg == "-c" || arg == "--ctx_size") {
49-
params.n_ctx = std::stoi(argv[++i]);
75+
if (++i >= argc) {
76+
invalid_param = true;
77+
break;
78+
}
79+
params.n_ctx = std::stoi(argv[i]);
5080
} else if (arg == "--memory_f16") {
5181
params.memory_f16 = true;
5282
} else if (arg == "--top_p") {
53-
params.top_p = std::stof(argv[++i]);
83+
if (++i >= argc) {
84+
invalid_param = true;
85+
break;
86+
}
87+
params.top_p = std::stof(argv[i]);
5488
} else if (arg == "--temp") {
55-
params.temp = std::stof(argv[++i]);
89+
if (++i >= argc) {
90+
invalid_param = true;
91+
break;
92+
}
93+
params.temp = std::stof(argv[i]);
5694
} else if (arg == "--repeat_last_n") {
57-
params.repeat_last_n = std::stoi(argv[++i]);
95+
if (++i >= argc) {
96+
invalid_param = true;
97+
break;
98+
}
99+
params.repeat_last_n = std::stoi(argv[i]);
58100
} else if (arg == "--repeat_penalty") {
59-
params.repeat_penalty = std::stof(argv[++i]);
101+
if (++i >= argc) {
102+
invalid_param = true;
103+
break;
104+
}
105+
params.repeat_penalty = std::stof(argv[i]);
60106
} else if (arg == "-b" || arg == "--batch_size") {
61-
params.n_batch = std::stoi(argv[++i]);
107+
if (++i >= argc) {
108+
invalid_param = true;
109+
break;
110+
}
111+
params.n_batch = std::stoi(argv[i]);
62112
} else if (arg == "-m" || arg == "--model") {
63-
params.model = argv[++i];
113+
if (++i >= argc) {
114+
invalid_param = true;
115+
break;
116+
}
117+
params.model = argv[i];
64118
} else if (arg == "-i" || arg == "--interactive") {
65119
params.interactive = true;
66120
} else if (arg == "--interactive-first") {
@@ -70,13 +124,21 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
70124
} else if (arg == "--color") {
71125
params.use_color = true;
72126
} else if (arg == "-r" || arg == "--reverse-prompt") {
73-
params.antiprompt.push_back(argv[++i]);
127+
if (++i >= argc) {
128+
invalid_param = true;
129+
break;
130+
}
131+
params.antiprompt.push_back(argv[i]);
74132
} else if (arg == "--perplexity") {
75133
params.perplexity = true;
76134
} else if (arg == "--ignore-eos") {
77135
params.ignore_eos = true;
78136
} else if (arg == "--n_parts") {
79-
params.n_parts = std::stoi(argv[++i]);
137+
if (++i >= argc) {
138+
invalid_param = true;
139+
break;
140+
}
141+
params.n_parts = std::stoi(argv[i]);
80142
} else if (arg == "-h" || arg == "--help") {
81143
gpt_print_usage(argc, argv, params);
82144
exit(0);
@@ -85,9 +147,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
85147
} else {
86148
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
87149
gpt_print_usage(argc, argv, params);
88-
exit(0);
150+
exit(1);
89151
}
90152
}
153+
if (invalid_param) {
154+
fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
155+
gpt_print_usage(argc, argv, params);
156+
exit(1);
157+
}
91158

92159
return true;
93160
}

0 commit comments

Comments
 (0)