#!/usr/bin/env python3
"""Extend the RobeCzech embedding matrix to cover the full tokenizer vocab.

Loads a masked-LM checkpoint from ``input_path``, grows the (tied) word
embeddings and LM-head bias from ``mask_id + 1`` rows to ``new_vocab`` rows
by copying the ``[UNK]`` row into every newly added slot, and saves the
result to ``output_path`` in both safetensors and legacy pickle formats.
"""
import argparse

import torch
import transformers

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input_path", type=str, help="Input directory")
    parser.add_argument("output_path", type=str, help="Output directory")
    args = parser.parse_args()

    robeczech = transformers.AutoModelForMaskedLM.from_pretrained(args.input_path, add_pooling_layer=True)

    # Token ids: [UNK]=3; [MASK]=51960 is the last row of the trained
    # embedding matrix; new_vocab=51997 is the target number of rows
    # (presumably the tokenizer's full vocabulary size — verify against the
    # tokenizer shipped with the checkpoint).
    unk_id, mask_id, new_vocab = 3, 51960, 51997

    # The input embeddings and the LM-head decoder must share one Parameter
    # (tied weights), and likewise bias/decoder.bias: resizing `.data` of one
    # then transparently resizes the other, so each tensor is handled once.
    assert robeczech.roberta.embeddings.word_embeddings.weight is robeczech.lm_head.decoder.weight
    assert robeczech.lm_head.bias is robeczech.lm_head.decoder.bias

    for weight in [robeczech.roberta.embeddings.word_embeddings.weight, robeczech.lm_head.bias]:
        original = weight.data
        # Guard against double-running the script or a mismatched checkpoint.
        assert original.shape[0] == mask_id + 1, original.shape
        # device=original.device keeps the slice-assign below valid even when
        # the model was loaded onto an accelerator rather than CPU.
        weight.data = torch.zeros((new_vocab,) + original.shape[1:], dtype=original.dtype, device=original.device)
        weight.data[:mask_id + 1] = original
        # Point row mask_id-1 (presumably an unused/reserved id — confirm
        # against the tokenizer) and every id past [MASK] at the [UNK] row.
        for new_unk in [mask_id - 1] + list(range(mask_id + 1, new_vocab)):
            weight.data[new_unk] = original[unk_id]

    # Keep the saved config consistent with the resized tensors; otherwise a
    # later from_pretrained would build a model with the old vocab size and
    # fail to load the new_vocab-row state dict. A no-op if the input config
    # already records new_vocab.
    robeczech.config.vocab_size = new_vocab

    # Save twice: default safetensors, plus legacy pytorch_model.bin for
    # consumers that cannot read safetensors.
    robeczech.save_pretrained(args.output_path)
    robeczech.save_pretrained(args.output_path, safe_serialization=False)