import argparse

import torch
import transformers


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input_path", type=str, help="Input directory with the original model")
    parser.add_argument("output_path", type=str, help="Output directory for the resized model")
    args = parser.parse_args()

    # Load the model whose embeddings will be resized. RobertaForMaskedLM
    # builds its encoder without a pooling layer and its constructor does not
    # accept an add_pooling_layer argument, so none is passed here.
    robeczech = transformers.AutoModelForMaskedLM.from_pretrained(args.input_path)

    # The checkpoint's embedding matrix has mask_id + 1 = 51961 rows, ending
    # at the [MASK] token, while the tokenizer vocabulary has 51997 entries;
    # id 3 is the [UNK] token.
    unk_id, mask_id, new_vocab = 3, 51960, 51997

    # The input embedding matrix is tied to the LM head decoder weight, and
    # the LM head bias is tied to the decoder bias, so resizing the two
    # tensors below updates all of them.
    assert robeczech.roberta.embeddings.word_embeddings.weight is robeczech.lm_head.decoder.weight
    assert robeczech.lm_head.bias is robeczech.lm_head.decoder.bias
    for weight in [robeczech.roberta.embeddings.word_embeddings.weight, robeczech.lm_head.bias]:
        original = weight.data
        assert original.shape[0] == mask_id + 1, original.shape
        # Grow the first dimension to the full vocabulary size and keep the
        # original rows; then set the id just below [MASK] as well as every
        # id past the original vocabulary to the [UNK] row.
        weight.data = torch.zeros((new_vocab,) + original.shape[1:], dtype=original.dtype)
        weight.data[:mask_id + 1] = original
        for new_unk in [mask_id - 1] + list(range(mask_id + 1, new_vocab)):
            weight.data[new_unk] = original[unk_id]
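
    # Optional sanity check (a minimal addition, not required for the fix):
    # every embedding row past the original vocabulary should now equal the
    # [UNK] embedding row.
    embeddings = robeczech.roberta.embeddings.word_embeddings.weight.data
    assert torch.equal(
        embeddings[mask_id + 1:],
        embeddings[unk_id].expand(new_vocab - mask_id - 1, -1),
    )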

    # Record the new vocabulary size in the configuration; otherwise the saved
    # config would still describe the smaller embedding matrix and the saved
    # checkpoint could not be loaded back.
    robeczech.config.vocab_size = new_vocab

    # Save both a safetensors checkpoint (the default in recent transformers
    # versions) and a legacy pytorch_model.bin checkpoint.
    robeczech.save_pretrained(args.output_path)
    robeczech.save_pretrained(args.output_path, safe_serialization=False)
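
    # Reload the saved model to confirm the resized vocabulary round-trips
    # (a quick sketch; it assumes args.output_path was just written above).
    reloaded = transformers.AutoModelForMaskedLM.from_pretrained(args.output_path)
    assert reloaded.roberta.embeddings.word_embeddings.weight.shape[0] == new_vocab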