{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# pip -q install sentencepiece\n", "# pip -q install numpy\n", "# pip -q install sentence_transformers\n", "# pip -q install datasets\n", "import sentencepiece as spm\n", "import numpy as np\n", "from datasets import load_dataset\n", "from collections import Counter\n", "from sentence_transformers import SentenceTransformer" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "model = SentenceTransformer('all-MiniLM-L6-v2')" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[\"My favourite food is anything I didn't have to cook myself.\", 'Now if he does off himself, everyone will think hes having a laugh screwing with people instead of actually dead', 'WHY THE FUCK IS BAYLESS ISOING', 'To make her feel threatened', 'Dirty Southern Wankers', \"OmG pEyToN iSn'T gOoD eNoUgH tO hElP uS iN tHe PlAyOfFs! Dumbass Broncos fans circa December 2015.\", 'Yes I heard abt the f bombs! That has to be why. Thanks for your reply:) until then hubby and I will anxiously wait 😝', 'We need more boards and to create a bit more space for [NAME]. Then we’ll be good.', 'Damn youtube and outrage drama is super lucrative for reddit', 'It might be linked to the trust factor of your friend.']\n" ] } ], "source": [ "dataset = load_dataset(\"go_emotions\")\n", "texts = dataset[\"train\"][\"text\"]\n", "print(texts[:10])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def binarize(embeddings, sensitivity=0.1):\n", "\treturn np.where(embeddings >= sensitivity, 1, 0)\n", "\n", "def preprocess(strings):\n", "\treturn \"\\n\".join([\"\".join(map(str, s)) for s in processed_string])\n", "\n", "# Obtain sentence embeddings\n", "embeddings = model.encode(texts)\n", "binary_hashes = binarize(embeddings)\n", "binary_string = preprocess(binary_hashes)\n", "print(binary_string[:500])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Save passage to a temporary file\n", "with open(\"passage.txt\", \"w\") as f:\n", "\tf.write(binary_string)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Training options documentation: https://github.com/google/sentencepiece/blob/master/doc/options.md\n", "# Training takes 3 hours to complete on GTX 1650 mobile\n", "spm.SentencePieceTrainer.train(\n", "\tinput='passage.txt',\n", "\tmodel_prefix='384_bit_comp',\n", "\tvocab_size=256 + 3, # To exclude , , \n", "\tcharacter_coverage=1.00,\n", "\tmax_sentencepiece_length=384,\n", "\tmodel_type='unigram',\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "length: 13\n", "encoded_tokens: ['▁0000000', '0000000000000001000000000000000000000', '00000000001000100', '1000000', '00000000000000000000000000000001000000000000000000000000000000000000000000000000000000', '00000000000000000001000000000000000000000000000000000', '0000000000000000000000000000000001000', '00000000000000000000000100000000000000000', '00000000010', '0000000000000000000000000000000000000100', '00000000000100000000000000000', '00000000010', '00001000']\n", "encoded_ids: 1ab2ed09d7a9617206894e0608\n", "same?: True\n", "count: Counter({'00000000010': 2, '▁0000000': 1, '0000000000000001000000000000000000000': 1, 
'00000000001000100': 1, '1000000': 1, '00000000000000000000000000000001000000000000000000000000000000000000000000000000000000': 1, '00000000000000000001000000000000000000000000000000000': 1, '0000000000000000000000000000000001000': 1, '00000000000000000000000100000000000000000': 1, '0000000000000000000000000000000000000100': 1, '00000000000100000000000000000': 1, '00001000': 1})\n" ] } ], "source": [ "bpe_processor = spm.SentencePieceProcessor(model_file='384_bit_comp.model')\n", "\n", "def encode_id(bit_text):\n", "\tencoded_pieces = bpe_processor.encode_as_pieces(bit_text)\n", "\t# Shift past the 3 special tokens so every piece id fits in one byte\n", "\tencoded_ids = [bpe_processor.piece_to_id(s) - 3 for s in encoded_pieces]\n", "\tassert all(0 <= id_ <= 255 for id_ in encoded_ids)\n", "\tstring_ids = \"\".join([format(id_, \"02x\") for id_ in encoded_ids])\n", "\treturn string_ids\n", "\n", "def decode_id(hex_string):\n", "\tu8_array = np.frombuffer(bytes.fromhex(hex_string), dtype='