InterProt ESM2 SAE Models

A set of sparse autoencoder (SAE) models trained on ESM2-650M activations using 1M protein sequences from UniProt. The SAE implementation largely follows Gao et al., using a Top-K activation function.
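For readers new to Top-K SAEs, here is a minimal PyTorch sketch of the idea from Gao et al.: subtract a learned pre-bias, encode, keep only the k largest latents per token (passed through ReLU), then decode and add the pre-bias back. This is a hypothetical illustration, not the interprot SparseAutoencoder class. The class name TopKSAE and the choice k = 64 are assumptions; the dimensions simply mirror the 1280-dim ESM2-650M activations and the 4096-latent SAEs.

```python
import torch
import torch.nn as nn


class TopKSAE(nn.Module):
    """Hypothetical Top-K SAE sketch (not the interprot implementation)."""

    def __init__(self, d_model: int, d_sae: int, k: int):
        super().__init__()
        self.k = k
        # Learned bias subtracted before encoding and added back after decoding
        self.b_pre = nn.Parameter(torch.zeros(d_model))
        self.encoder = nn.Linear(d_model, d_sae, bias=True)
        self.decoder = nn.Linear(d_sae, d_model, bias=False)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        pre = self.encoder(x - self.b_pre)  # (..., d_sae) pre-activations
        # Keep the k largest pre-activations per token; all other latents are zero
        topk = torch.topk(pre, self.k, dim=-1)
        acts = torch.zeros_like(pre)
        acts.scatter_(-1, topk.indices, torch.relu(topk.values))
        return acts

    def forward(self, x: torch.Tensor):
        acts = self.encode(x)
        recon = self.decoder(acts) + self.b_pre  # reconstruction in ESM space
        return recon, acts


torch.manual_seed(0)
sae = TopKSAE(d_model=1280, d_sae=4096, k=64)
x = torch.randn(10, 1280)  # 10 tokens of stand-in (random) ESM activations
recon, acts = sae(x)
print(acts.shape, recon.shape)  # torch.Size([10, 4096]) torch.Size([10, 1280])
```

The fixed k makes sparsity a hyperparameter rather than something tuned indirectly through an L1 penalty: every token uses at most k of the 4096 latents.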

For more information, see our preprint. You can explore and interact with our SAEs at https://interprot.com.

Installation

pip install git+https://github.com/etowahadams/interprot.git

Usage

Load ESM and the SAE

import torch
from transformers import AutoTokenizer, EsmModel
from safetensors.torch import load_file
from interprot.sae_model import SparseAutoencoder
from huggingface_hub import hf_hub_download

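ESM_DIM = 1280  # hidden size of ESM2-650M
SAE_DIM = 4096  # number of SAE latents (dictionary size)
LAYER = 24      # ESM layer whose activations the SAE was trained on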

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load ESM model
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
esm_model.to(device)
esm_model.eval()

# Load SAE model
checkpoint_path = hf_hub_download(
    repo_id="liambai/InterProt-ESM2-SAEs",
    filename="esm2_plm1280_l24_sae4096.safetensors"
)
sae_model = SparseAutoencoder(ESM_DIM, SAE_DIM)
sae_model.load_state_dict(load_file(checkpoint_path))
sae_model.to(device)
sae_model.eval()

ESM -> SAE inference on an amino acid sequence of length L

seq = "TTCCPSIVARSNFNVCRLPGTPEALCATYTGCIIIPGATCPGDYAN"

# Tokenize sequence and run ESM inference
inputs = tokenizer(seq, padding=True, return_tensors="pt").to(device)
with torch.no_grad():
    outputs = esm_model(**inputs, output_hidden_states=True)

# esm_layer_acts has shape (L+2, ESM_DIM), +2 for BoS and EoS tokens
esm_layer_acts = outputs.hidden_states[LAYER][0]

# Using ESM embeddings from LAYER, run SAE inference
sae_acts = sae_model.get_acts(esm_layer_acts) # (L+2, SAE_DIM)
sae_acts
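Because the first and last rows of sae_acts belong to the BoS and EoS tokens, it is usually convenient to strip them before interpreting features. The sketch below does that and lists the strongest latents with the residue where each peaks. The helper names residue_acts and top_latents are hypothetical, and padding handling assumes right-side padding when several sequences are batched. With the objects above it would be called as residue_acts(sae_acts, inputs["attention_mask"][0]); the demo uses a small synthetic tensor so it runs on its own.

```python
import torch


def residue_acts(sae_acts: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Drop BoS/EoS and any (right-side) padding from one sequence's SAE activations.

    sae_acts: (T, SAE_DIM); attention_mask: (T,) with 1 for real tokens, 0 for padding.
    """
    n_real = int(attention_mask.sum())  # L + 2 real tokens (BoS, residues, EoS)
    return sae_acts[1 : n_real - 1]     # (L, SAE_DIM), one row per residue


def top_latents(acts: torch.Tensor, n: int = 5):
    """Return (latent index, max activation, residue position) for the n strongest latents."""
    max_vals, max_pos = acts.max(dim=0)  # per-latent maximum over residues
    top = torch.topk(max_vals, n)
    return [(int(i), float(max_vals[i]), int(max_pos[i])) for i in top.indices]


# Synthetic demo: L = 4 residues, 8 latents, plus BoS and EoS rows
demo_acts = torch.zeros(6, 8)
demo_acts[0, 0] = 9.0  # BoS activation, dropped by residue_acts
demo_acts[2, 3] = 5.0  # latent 3 fires on residue 1
demo_acts[4, 7] = 2.0  # latent 7 fires on residue 3
res = residue_acts(demo_acts, torch.ones(6))
print(top_latents(res, n=2))  # [(3, 5.0, 1), (7, 2.0, 3)]
```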

Note on the default checkpoint on interprot.com

In November 2024, we shared an earlier version of our layer 24 SAE on X, and the community contributed a great deal to identifying its features, so we have kept that version as the default on interprot.com. We have since retrained the layer 24 SAE with slightly different hyperparameters and on more sequences (1M vs. the original 100K). The new SAE is named esm2_plm1280_l24_sae4096.safetensors, whereas the original is named esm2_plm1280_l24_sae4096_100k.safetensors.

We recommend using esm2_plm1280_l24_sae4096.safetensors; to reproduce the default SAE on interprot.com, use esm2_plm1280_l24_sae4096_100k.safetensors instead. The SAEs for all other layers are trained with the same configuration as esm2_plm1280_l24_sae4096.safetensors.
