#!/usr/bin/env python3
import argparse

import torch
import transformers

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input_path", type=str, help="Input directory")
    parser.add_argument("output_path", type=str, help="Output directory")
    args = parser.parse_args()

    robeczech = transformers.AutoModelForMaskedLM.from_pretrained(args.input_path, add_pooling_layer=True)

    # Vocabulary constants: [UNK] has id 3, [MASK] has id 51960, and the
    # embedding matrices are enlarged from the original 51961 rows to 51997.
    unk_id, mask_id, new_vocab = 3, 51960, 51997

    # The input embeddings are tied to the LM head decoder, so resizing the
    # word embeddings (plus the LM head bias) also resizes the decoder.
    assert robeczech.roberta.embeddings.word_embeddings.weight is robeczech.lm_head.decoder.weight
    assert robeczech.lm_head.bias is robeczech.lm_head.decoder.bias

    # robeczech.lm_head.decoder.weight is skipped: it is tied to the word embeddings.
    for weight in [robeczech.roberta.embeddings.word_embeddings.weight, robeczech.lm_head.bias]:
        original = weight.data
        assert original.shape[0] == mask_id + 1, original.shape
        # Grow the first dimension to new_vocab, keeping the original rows.
        weight.data = torch.zeros((new_vocab,) + original.shape[1:], dtype=original.dtype)
        weight.data[:mask_id + 1] = original
        # Initialize the row just below [MASK] and every newly added row
        # above it from the [UNK] row.
        for new_unk in [mask_id - 1] + list(range(mask_id + 1, new_vocab)):
            weight.data[new_unk] = original[unk_id]

    # Save both the safetensors and the legacy pytorch_model.bin serializations.
    robeczech.save_pretrained(args.output_path)
    robeczech.save_pretrained(args.output_path, safe_serialization=False)
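A minimal sanity check for the resized checkpoint, assuming the script above has already been run ("output_path" here is a placeholder for its output directory; the ids follow the constants in the script):

#!/usr/bin/env python3
# Sketch: verify that the resized embeddings keep the original rows and that
# the redirected id below [MASK] and every newly added id share the [UNK] row.
import torch
import transformers

robeczech = transformers.AutoModelForMaskedLM.from_pretrained("output_path")
embeddings = robeczech.roberta.embeddings.word_embeddings.weight
assert embeddings.shape[0] == 51997
for new_id in [51959] + list(range(51961, 51997)):
    assert torch.equal(embeddings[new_id], embeddings[3])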