AR-Diffusion: Cannot run inference with the pretrained Commongen checkpoint #86

@chuanky


Dear Author,

I downloaded the Commongen checkpoint commongen_checkpoint-40000 from Google Drive

and ran the following command from gen.sh:

# Commongen
FILE_NAME=commongen
STEP=40000
MODEL_NAME=models/bert-base-uncased

python ./gen_utils/generate.py \
model.name=$MODEL_NAME batch_size=800 \
exp.name=$FILE_NAME load_step=$STEP \
data.name=commongen max_pos_len=128 num_samples=50 \
intermediate_size=512 num_attention_heads=8 \
in_channels=64 out_channels=64 time_channels=64 \
skip_sample=True gen_timesteps=20 \
schedule_sampler='xy_uniform' time_att=True att_strategy='txl' \
tgt_len=54 prediction=True

It fails with: FileNotFoundError: [Errno 2] No such file or directory: './my_output/commongen/commongen/model/model_checkpoint-40000'

So I copied the downloaded checkpoint to ./my_output/commongen/commongen/model/ and renamed it to model_checkpoint-40000.
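The copy-and-rename step above can be scripted. This is a minimal sketch: `place_checkpoint` is a hypothetical helper, and the target directory and `model_checkpoint-<step>` name are taken from the FileNotFoundError message, not from the AR-Diffusion source (how generate.py builds the path from exp.name and data.name is an assumption here). It handles the checkpoint whether it was downloaded as a single file or as a directory.

```python
import shutil
from pathlib import Path


def place_checkpoint(downloaded: Path, model_dir: Path, step: int) -> Path:
    """Copy a downloaded checkpoint to <model_dir>/model_checkpoint-<step>.

    Hypothetical helper: the naming pattern is inferred from the
    FileNotFoundError message, not from the AR-Diffusion code.
    """
    # Create ./my_output/<...>/model if it does not exist yet.
    model_dir.mkdir(parents=True, exist_ok=True)
    target = model_dir / f"model_checkpoint-{step}"
    if downloaded.is_dir():
        # Checkpoint shipped as a directory: copy the whole tree.
        shutil.copytree(downloaded, target, dirs_exist_ok=True)
    else:
        # Checkpoint shipped as a single file: copy it, keeping metadata.
        shutil.copy2(downloaded, target)
    return target
```

For this issue the call would be `place_checkpoint(Path("commongen_checkpoint-40000"), Path("./my_output/commongen/commongen/model"), 40000)`. `dirs_exist_ok` requires Python 3.8 or later.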

After doing this, the console shows:

RuntimeError: Error(s) in loading state_dict for CrossAttention_Diffusion_LM:
        Missing key(s) in state_dict: "position_ids", "position_embeddings.weight". 
        Unexpected key(s) in state_dict: "embed_positions.weight". 

I'm not sure whether I did something wrong, or what I should do next.
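To check whether the mismatch above is only a naming difference, the checkpoint's keys can be compared with the keys the model expects, and renamed where appropriate. This is a sketch under assumptions: `diff_keys` and `rename_keys` are hypothetical helpers that work on any mapping (a PyTorch state_dict is an OrderedDict of tensors), and renaming embed_positions.weight to position_embeddings.weight is only safe if both tensors have the same shape and meaning. The rename does not create position_ids.

```python
from collections import OrderedDict
from typing import Dict, Mapping, Set, Tuple, TypeVar

V = TypeVar("V")


def diff_keys(ckpt: Mapping[str, V], expected: Set[str]) -> Tuple[Set[str], Set[str]]:
    """Return (missing, unexpected) keys, like load_state_dict's error report."""
    have = set(ckpt)
    return expected - have, have - expected


def rename_keys(ckpt: Mapping[str, V], renames: Dict[str, str]) -> "OrderedDict[str, V]":
    """Copy ckpt in order, renaming any key listed in renames; values are untouched."""
    out: "OrderedDict[str, V]" = OrderedDict()
    for key, value in ckpt.items():
        out[renames.get(key, key)] = value
    return out
```

With PyTorch this would roughly be `state = torch.load(path, map_location="cpu")`, `expected = set(model.state_dict())`, then `result = model.load_state_dict(rename_keys(state, {"embed_positions.weight": "position_embeddings.weight"}), strict=False)`, after which `result.missing_keys` should list only position_ids. If, as in many BERT-style embedding modules, position_ids is a buffer filled from torch.arange when the model is built, leaving it at its initialized value may be fine, but that is unverified for CrossAttention_Diffusion_LM. A different key name can also mean the checkpoint came from a different version of the model code.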
