Spaces:
Running
on
L4
Running
on
L4
File size: 3,121 Bytes
a660631 ef5a142 a660631 f521e88 a660631 f521e88 a660631 ef5a142 f521e88 a660631 ef5a142 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 ef5a142 f521e88 a660631 ef5a142 f521e88 a660631 ef5a142 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import gc
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Callable
import numpy as np
import PIL.Image
import torch
from controlnet_aux import (
CannyDetector,
ContentShuffleDetector,
HEDdetector,
LineartAnimeDetector,
LineartDetector,
MidasDetector,
MLSDdetector,
NormalBaeDetector,
OpenposeDetector,
PidiNetDetector,
)
from controlnet_aux.util import HWC3
from cv_utils import resize_image
from depth_estimator import DepthEstimator
from image_segmentor import ImageSegmentor
class Preprocessor:
    """Hold a single image preprocessor (annotator) and apply it to images.

    Select a preprocessor by name with :meth:`load`, then call the instance
    with a PIL image. Only the most recently loaded model is kept on the
    instance.
    """

    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self) -> None:
        """Create an empty preprocessor; nothing is loaded until ``load``."""
        self.model: Callable = None  # type: ignore
        # Name of the currently loaded preprocessor; "" means none loaded.
        self.name = ""

    def load(self, name: str) -> None:
        """Load the preprocessor called *name*, replacing the current one.

        Does nothing when *name* is already the loaded preprocessor.

        Args:
            name: One of "HED", "Midas", "MLSD", "Openpose", "PidiNet",
                "NormalBae", "Lineart", "LineartAnime", "Canny",
                "ContentShuffle", "DPT" or "UPerNet".

        Raises:
            ValueError: If *name* is not a known preprocessor. The current
                model and name are left untouched in that case.
        """
        if name == self.name:
            return

        # Lazy factories: nothing is instantiated until one is selected.
        factories = {
            "HED": lambda: HEDdetector.from_pretrained(self.MODEL_ID),
            "Midas": lambda: MidasDetector.from_pretrained(self.MODEL_ID),
            "MLSD": lambda: MLSDdetector.from_pretrained(self.MODEL_ID),
            "Openpose": lambda: OpenposeDetector.from_pretrained(self.MODEL_ID),
            "PidiNet": lambda: PidiNetDetector.from_pretrained(self.MODEL_ID),
            "NormalBae": lambda: NormalBaeDetector.from_pretrained(self.MODEL_ID),
            "Lineart": lambda: LineartDetector.from_pretrained(self.MODEL_ID),
            "LineartAnime": lambda: LineartAnimeDetector.from_pretrained(
                self.MODEL_ID
            ),
            "Canny": lambda: CannyDetector(),
            "ContentShuffle": lambda: ContentShuffleDetector(),
            "DPT": lambda: DepthEstimator(),
            "UPerNet": lambda: ImageSegmentor(),
        }
        factory = factories.get(name)
        if factory is None:
            known = ", ".join(sorted(factories))
            msg = f"Unknown preprocessor {name!r}; expected one of: {known}"
            raise ValueError(msg)

        self.model = factory()
        # Collect garbage *before* emptying the CUDA cache: empty_cache() can
        # only release blocks that are no longer referenced, so anything the
        # previous model still holds through uncollected cycles must go first.
        gc.collect()
        torch.cuda.empty_cache()
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:  # noqa: ANN003
        """Run the currently loaded preprocessor on *image*.

        Args:
            image: Input image.
            **kwargs: Forwarded to the underlying model, except for
                ``detect_resolution`` (Canny and Midas) and
                ``image_resolution`` (Midas), which are consumed here to
                resize the model input / output. For Midas both default
                to 512.

        Returns:
            For "Canny" and "Midas", the model output wrapped with
            ``PIL.Image.fromarray``; for every other preprocessor, whatever
            the underlying model returns.
        """
        if self.name == "Canny":
            # Only when a detect resolution is given is the image converted
            # to an array and resized here; otherwise it is handed to the
            # detector unchanged.
            if "detect_resolution" in kwargs:
                detect_resolution = kwargs.pop("detect_resolution")
                image = np.array(image)
                image = HWC3(image)
                image = resize_image(image, resolution=detect_resolution)
            image = self.model(image, **kwargs)
            return PIL.Image.fromarray(image)

        if self.name == "Midas":
            detect_resolution = kwargs.pop("detect_resolution", 512)
            image_resolution = kwargs.pop("image_resolution", 512)
            # Resize to the detection resolution, run the model, then resize
            # the result to the requested output resolution.
            image = np.array(image)
            image = HWC3(image)
            image = resize_image(image, resolution=detect_resolution)
            image = self.model(image, **kwargs)
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            return PIL.Image.fromarray(image)

        # All other preprocessors take the image and kwargs as-is.
        return self.model(image, **kwargs)
|