#!/usr/bin/env python3
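"""Extend a RobeCzech masked-LM checkpoint to the full tokenizer vocabulary.

The loaded checkpoint has 51,961 embedding rows while the tokenizer uses
51,997 ids; this script grows the tied embedding/LM-head matrices and fills
id 51959 and ids 51961-51996 with the [UNK] embedding (id 3) before saving.
(Summary inferred from the constants and assertions in the code below.)
"""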
import argparse
import torch
import transformers

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input_path", type=str, help="Input directory")
    parser.add_argument("output_path", type=str, help="Output directory")
    args = parser.parse_args()

    # RobertaForMaskedLM builds its encoder without a pooling layer and does
    # not accept an `add_pooling_layer` argument, so none is passed here.
    robeczech = transformers.AutoModelForMaskedLM.from_pretrained(args.input_path)

    # [UNK] id, [MASK] id (last id in the original matrix), target vocab size.
    unk_id, mask_id, new_vocab = 3, 51960, 51997

    # The input embeddings are tied to the LM-head decoder, so resizing the
    # embedding matrix below also resizes the decoder weights; only the bias
    # needs a separate pass.
    assert robeczech.roberta.embeddings.word_embeddings.weight is robeczech.lm_head.decoder.weight
    assert robeczech.lm_head.bias is robeczech.lm_head.decoder.bias

    for weight in [robeczech.roberta.embeddings.word_embeddings.weight, robeczech.lm_head.bias]:
        original = weight.data
        assert original.shape[0] == mask_id + 1, original.shape  # covers ids 0..mask_id
        # Grow to `new_vocab` rows, keep the original rows, and copy the [UNK]
        # entry into id mask_id - 1 (51959) and ids mask_id + 1..new_vocab - 1.
        weight.data = torch.zeros((new_vocab,) + original.shape[1:], dtype=original.dtype)
        weight.data[:mask_id + 1] = original
        for new_unk in [mask_id - 1] + list(range(mask_id + 1, new_vocab)):
            weight.data[new_unk] = original[unk_id]

    # Keep the saved config consistent with the resized weights.
    robeczech.config.vocab_size = new_vocab

    # Save both the default safetensors checkpoint and a legacy pytorch_model.bin.
    robeczech.save_pretrained(args.output_path)
    robeczech.save_pretrained(args.output_path, safe_serialization=False)
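
    # Optional sanity check (a sketch, not part of the original script): the
    # saved model should now reload with the full vocabulary, e.g.
    # reloaded = transformers.AutoModelForMaskedLM.from_pretrained(args.output_path)
    # assert reloaded.roberta.embeddings.word_embeddings.weight.shape[0] == new_vocab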