@@ -26,41 +26,95 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
26
26
params.n_threads = std::max (1 , (int32_t ) std::thread::hardware_concurrency ());
27
27
}
28
28
29
+ bool invalid_param = false ;
30
+ std::string arg;
29
31
for (int i = 1 ; i < argc; i++) {
30
- std::string arg = argv[i];
32
+ arg = argv[i];
31
33
32
34
if (arg == " -s" || arg == " --seed" ) {
33
- params.seed = std::stoi (argv[++i]);
35
+ if (++i >= argc) {
36
+ invalid_param = true ;
37
+ break ;
38
+ }
39
+ params.seed = std::stoi (argv[i]);
34
40
} else if (arg == " -t" || arg == " --threads" ) {
35
- params.n_threads = std::stoi (argv[++i]);
41
+ if (++i >= argc) {
42
+ invalid_param = true ;
43
+ break ;
44
+ }
45
+ params.n_threads = std::stoi (argv[i]);
36
46
} else if (arg == " -p" || arg == " --prompt" ) {
37
- params.prompt = argv[++i];
47
+ if (++i >= argc) {
48
+ invalid_param = true ;
49
+ break ;
50
+ }
51
+ params.prompt = argv[i];
38
52
} else if (arg == " -f" || arg == " --file" ) {
39
- std::ifstream file (argv[++i]);
53
+ if (++i >= argc) {
54
+ invalid_param = true ;
55
+ break ;
56
+ }
57
+ std::ifstream file (argv[i]);
40
58
std::copy (std::istreambuf_iterator<char >(file), std::istreambuf_iterator<char >(), back_inserter (params.prompt ));
41
59
if (params.prompt .back () == ' \n ' ) {
42
60
params.prompt .pop_back ();
43
61
}
44
62
} else if (arg == " -n" || arg == " --n_predict" ) {
45
- params.n_predict = std::stoi (argv[++i]);
63
+ if (++i >= argc) {
64
+ invalid_param = true ;
65
+ break ;
66
+ }
67
+ params.n_predict = std::stoi (argv[i]);
46
68
} else if (arg == " --top_k" ) {
47
- params.top_k = std::stoi (argv[++i]);
69
+ if (++i >= argc) {
70
+ invalid_param = true ;
71
+ break ;
72
+ }
73
+ params.top_k = std::stoi (argv[i]);
48
74
} else if (arg == " -c" || arg == " --ctx_size" ) {
49
- params.n_ctx = std::stoi (argv[++i]);
75
+ if (++i >= argc) {
76
+ invalid_param = true ;
77
+ break ;
78
+ }
79
+ params.n_ctx = std::stoi (argv[i]);
50
80
} else if (arg == " --memory_f16" ) {
51
81
params.memory_f16 = true ;
52
82
} else if (arg == " --top_p" ) {
53
- params.top_p = std::stof (argv[++i]);
83
+ if (++i >= argc) {
84
+ invalid_param = true ;
85
+ break ;
86
+ }
87
+ params.top_p = std::stof (argv[i]);
54
88
} else if (arg == " --temp" ) {
55
- params.temp = std::stof (argv[++i]);
89
+ if (++i >= argc) {
90
+ invalid_param = true ;
91
+ break ;
92
+ }
93
+ params.temp = std::stof (argv[i]);
56
94
} else if (arg == " --repeat_last_n" ) {
57
- params.repeat_last_n = std::stoi (argv[++i]);
95
+ if (++i >= argc) {
96
+ invalid_param = true ;
97
+ break ;
98
+ }
99
+ params.repeat_last_n = std::stoi (argv[i]);
58
100
} else if (arg == " --repeat_penalty" ) {
59
- params.repeat_penalty = std::stof (argv[++i]);
101
+ if (++i >= argc) {
102
+ invalid_param = true ;
103
+ break ;
104
+ }
105
+ params.repeat_penalty = std::stof (argv[i]);
60
106
} else if (arg == " -b" || arg == " --batch_size" ) {
61
- params.n_batch = std::stoi (argv[++i]);
107
+ if (++i >= argc) {
108
+ invalid_param = true ;
109
+ break ;
110
+ }
111
+ params.n_batch = std::stoi (argv[i]);
62
112
} else if (arg == " -m" || arg == " --model" ) {
63
- params.model = argv[++i];
113
+ if (++i >= argc) {
114
+ invalid_param = true ;
115
+ break ;
116
+ }
117
+ params.model = argv[i];
64
118
} else if (arg == " -i" || arg == " --interactive" ) {
65
119
params.interactive = true ;
66
120
} else if (arg == " --interactive-first" ) {
@@ -70,13 +124,21 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
70
124
} else if (arg == " --color" ) {
71
125
params.use_color = true ;
72
126
} else if (arg == " -r" || arg == " --reverse-prompt" ) {
73
- params.antiprompt .push_back (argv[++i]);
127
+ if (++i >= argc) {
128
+ invalid_param = true ;
129
+ break ;
130
+ }
131
+ params.antiprompt .push_back (argv[i]);
74
132
} else if (arg == " --perplexity" ) {
75
133
params.perplexity = true ;
76
134
} else if (arg == " --ignore-eos" ) {
77
135
params.ignore_eos = true ;
78
136
} else if (arg == " --n_parts" ) {
79
- params.n_parts = std::stoi (argv[++i]);
137
+ if (++i >= argc) {
138
+ invalid_param = true ;
139
+ break ;
140
+ }
141
+ params.n_parts = std::stoi (argv[i]);
80
142
} else if (arg == " -h" || arg == " --help" ) {
81
143
gpt_print_usage (argc, argv, params);
82
144
exit (0 );
@@ -85,9 +147,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
85
147
} else {
86
148
fprintf (stderr, " error: unknown argument: %s\n " , arg.c_str ());
87
149
gpt_print_usage (argc, argv, params);
88
- exit (0 );
150
+ exit (1 );
89
151
}
90
152
}
153
+ if (invalid_param) {
154
+ fprintf (stderr, " error: invalid parameter for argument: %s\n " , arg.c_str ());
155
+ gpt_print_usage (argc, argv, params);
156
+ exit (1 );
157
+ }
91
158
92
159
return true ;
93
160
}
0 commit comments