diff --git a/generate.py b/generate.py index ec31c42..733fe30 100644 --- a/generate.py +++ b/generate.py @@ -1,6 +1,4 @@ -""" -@uthor: Prakhar -""" + import os import argparse import torch @@ -23,7 +21,7 @@ def choose_from_top_k_top_n(probs, k=50, p=0.8): f.append(k) pr.append(v) if t>=p: - break + break top_prob = pr / np.sum(pr) token_id = np.random.choice(f, 1, p = top_prob)