# Model description

This is a `LogisticRegressionCV` model trained on averages of patch embeddings from the Imagenette dataset. It forms the GAM component of an Emb-GAM extended to images. Patch embeddings are meant to be extracted with the `google/vit-base-patch16-224` ViT checkpoint.
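The additive structure is what makes this setup interpretable: because the classifier is linear, the logits over the pooled embedding equal the sum of per-patch contributions (intercept aside). Below is a minimal sketch of that identity using hypothetical placeholder arrays, with shapes following ViT-base (196 patches of dimension 768) and the 10 Imagenette classes:

```python
import numpy as np

# Hypothetical placeholders: one 768-dim embedding per 16x16 patch of a 224x224 image
patch_embeddings = np.random.randn(196, 768)
coef = np.random.randn(10, 768)  # stand-in for the fitted coefficients

# Logits over the pooled (summed) embedding, intercept omitted...
logits_pooled = coef @ patch_embeddings.sum(axis=0)

# ...equal the sum of per-patch contributions: the additive (GAM) property
per_patch = coef @ patch_embeddings.T  # shape (10, 196)
assert np.allclose(logits_pooled, per_patch.sum(axis=1))
```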
## Intended uses & limitations
This model is not intended to be used in production.
## Training Procedure

### Hyperparameters

The model was trained with the hyperparameters below.
| Hyperparameter | Value |
|----------------|-------|
| Cs | 10 |
| class_weight | |
| cv | StratifiedKFold(n_splits=5, random_state=1, shuffle=True) |
| dual | False |
| fit_intercept | True |
| intercept_scaling | 1.0 |
| l1_ratios | |
| max_iter | 100 |
| multi_class | auto |
| n_jobs | |
| penalty | l2 |
| random_state | 1 |
| refit | False |
| scoring | |
| solver | lbfgs |
| tol | 0.0001 |
| verbose | 0 |
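For reference, an estimator with the same configuration can be reconstructed directly in scikit-learn. This is a sketch only; the training features `X` and labels `y` are not part of this card:

```python
from sklearn.linear_model import LogisticRegressionCV
from sklearn.model_selection import StratifiedKFold

# Mirrors the hyperparameter table above; unset values keep scikit-learn defaults
clf = LogisticRegressionCV(
    Cs=10,
    cv=StratifiedKFold(n_splits=5, random_state=1, shuffle=True),
    dual=False,
    fit_intercept=True,
    intercept_scaling=1.0,
    max_iter=100,
    multi_class="auto",
    penalty="l2",
    random_state=1,
    refit=False,
    solver="lbfgs",
    tol=0.0001,
    verbose=0,
)
# clf.fit(X, y)  # X: pooled patch embeddings, y: Imagenette labels (hypothetical)
```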
## Model Plot
The model plot is below.
```
LogisticRegressionCV(cv=StratifiedKFold(n_splits=5, random_state=1, shuffle=True),
                     random_state=1, refit=False)
```
## Evaluation Results

You can find details about the evaluation process and the evaluation results below.
| Metric | Value |
|--------|-------|
| accuracy | 0.99465 |
| f1 score | 0.99465 |
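The identical accuracy and F1 values suggest micro-averaged F1, which coincides with accuracy in single-label multiclass classification. A sketch of how such scores are typically computed with scikit-learn; `y_test` and `y_pred` are hypothetical placeholders, and the averaging choice is an assumption:

```python
from sklearn.metrics import accuracy_score, f1_score

# y_test: held-out labels, y_pred: outputs of logistic_regression.predict(...)
accuracy = accuracy_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred, average="micro")  # averaging mode assumed, not stated in the card
```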
## How to Get Started with the Model
Use the code below to get started with the model.
```python
import os
import pickle

import torch
from PIL import Image
from skops import hub_utils
from transformers import AutoFeatureExtractor, AutoModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ViT checkpoint used to extract the patch embeddings
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
model = AutoModel.from_pretrained("google/vit-base-patch16-224").eval().to(device)

# Download the fitted scikit-learn model from the Hub
os.makedirs("emb-gam-vit", exist_ok=True)
hub_utils.download(repo_id="Ramos-Ramos/emb-gam-vit", dst="emb-gam-vit")
with open("emb-gam-vit/model.pkl", "rb") as file:
    logistic_regression = pickle.load(file)

# Extract patch embeddings, dropping the [CLS] token at index 0
img = Image.open("examples/english_springer.png")
inputs = {k: v.to(device) for k, v in feature_extractor(img, return_tensors="pt").items()}
with torch.no_grad():
    patch_embeddings = model(**inputs).last_hidden_state[0, 1:].cpu()

# Classify the pooled embedding and compute per-patch contributions
pred = logistic_regression.predict(patch_embeddings.sum(dim=0, keepdim=True))
patch_contributions = logistic_regression.coef_ @ patch_embeddings.T.numpy()
```
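The per-patch contributions lend themselves to visualization: ViT-base with 16x16 patches on a 224x224 input yields a 14x14 grid, so the 196 contributions for the predicted class can be reshaped into a heatmap. A sketch using matplotlib; the plotting code is illustrative and not part of the original card:

```python
import matplotlib.pyplot as plt
import numpy as np

# Row of patch_contributions corresponding to the predicted class
class_idx = int(np.where(logistic_regression.classes_ == pred[0])[0][0])
heatmap = patch_contributions[class_idx].reshape(14, 14)

plt.imshow(heatmap, cmap="coolwarm")
plt.colorbar(label="contribution to predicted-class logit")
plt.axis("off")
plt.show()
```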
## Model Card Authors

This model card was written by the following authors:
Patrick Ramos and Ryan Ramos
## Model Card Contact

You can contact the model card authors through the following channels:
[More Information Needed]
## Citation
Below you can find information related to citation.
BibTeX:
```bibtex
@article{singh2022emb,
  title={Emb-GAM: an Interpretable and Efficient Predictor using Pre-trained Language Models},
  author={Singh, Chandan and Gao, Jianfeng},
  journal={arXiv preprint arXiv:2209.11799},
  year={2022}
}
```
## Additional Content

### Confusion matrix
![confusion_matrix](/Ramos-Ramos/emb-gam-vit/resolve/main/confusion_matrix.png)