Spaces:
Running
on
T4
Running
on
T4
from huggingface_hub import hf_hub_download | |
from shapely.validation import make_valid | |
from shapely.geometry import Polygon | |
from ultralytics import YOLO | |
from PIL import Image | |
import numpy as np | |
import os | |
from reading_order import OrderPolygons | |
class SegmentImage:
    """Class for segmenting document image regions and text lines.

    A line detection model is always loaded. A region detection model is
    loaded only when ``region_model_path`` is given; without it every detected
    text line is assigned to a single default region covering the whole image.
    """

    # File names of the model weights inside the Hugging Face Hub repositories
    LINE_MODEL_FILENAME = "lines_20240827.pt"
    REGION_MODEL_FILENAME = "tuomiokirja_regions_04122023.pt"

    def __init__(self,
                 line_model_path,
                 device,
                 line_iou=0.5,
                 region_iou=0.5,
                 line_overlap=0.5,
                 line_nms_iou=0.7,
                 region_nms_iou=0.3,
                 line_conf_threshold=0.25,
                 region_conf_threshold=0.25,
                 region_model_path=None,
                 order_regions=True,
                 region_half_precision=False,
                 line_half_precision=False):
        # Hub repository id of the text line detection model
        self.line_model_path = line_model_path
        # Hub repository id of the text region detection model (optional)
        self.region_model_path = region_model_path
        # IoU thresholds used in the non-maximum suppression (NMS) process to
        # determine which prediction boxes are suppressed based on their
        # overlap with other boxes
        self.line_nms_iou = line_nms_iou
        self.region_nms_iou = region_nms_iou
        # IoU threshold above which two text line polygons are merged
        self.line_iou = line_iou
        # IoU threshold above which two text region polygons are merged
        self.region_iou = region_iou
        # Share of the smaller line polygon that must be covered by the
        # intersection for two line polygons to be merged
        self.line_overlap = line_overlap
        # Confidence threshold for line detection
        self.line_conf_threshold = line_conf_threshold
        # Confidence threshold for region detection
        self.region_conf_threshold = region_conf_threshold
        # Device to be used ('cpu', gpu '0', gpu '1' etc.)
        self.device = device
        # Whether a reading order is also estimated for the region detections
        self.order_regions = order_regions
        # Whether half precision (FP16) is used by the region and line models
        self.region_half_precision = region_half_precision
        self.line_half_precision = line_half_precision
        self.order_poly = OrderPolygons()
        # Initialize segmentation model(s)
        self.line_model = self.init_line_model()
        # Always define the attribute so that the rest of the class can test
        # for a missing region model instead of raising AttributeError
        self.region_model = self.init_region_model() if self.region_model_path else None

    def _load_model(self, repo_id, filename, description):
        """Downloads (or reads from cache) model weights and wraps them in YOLO.

        Returns None and prints the error if anything goes wrong, so that a
        failure here does not abort construction of the class.
        """
        try:
            cached_model_path = hf_hub_download(repo_id=repo_id, filename=filename)
            return YOLO(cached_model_path)
        except Exception as e:
            print('Failed to load the %s detection model: %s' % (description, e))
            return None

    def init_line_model(self):
        """Function for initializing the line detection model.

        Returns the loaded model, or None if loading failed.
        """
        return self._load_model(self.line_model_path, self.LINE_MODEL_FILENAME, 'line')

    def init_region_model(self):
        """Function for initializing the region detection model.

        Returns the loaded model, or None if loading failed.
        """
        return self._load_model(self.region_model_path, self.REGION_MODEL_FILENAME, 'region')

    def get_region_ids(self, coords, max_min, classes, names, box_confs, img_shape):
        """Function for creating unique id for each detected region.

        The input lists are read in lockstep and the result is as long as the
        shortest of them. Returns one dict per region with the keys 'coords',
        'max_min', 'class', 'name', 'conf', 'id' and 'img_shape'.
        """
        res = []
        rows = zip(coords, max_min, classes, box_confs)
        for i, (poly, extent, cls, conf) in enumerate(rows):
            res.append({'coords': poly,
                        'max_min': extent,
                        'class': str(cls),
                        # Region name corresponding to the class index
                        'name': names[cls],
                        'conf': conf,
                        # Simple index-based id
                        'id': str(i),
                        'img_shape': img_shape})
        return res

    def get_max_min(self, polygons):
        """Creates an array with the minimum and maximum
        x and y values of the input polygons.

        Each row is [max x, min x, max y, min y]; the row of an empty polygon
        is left as zeros.
        """
        xy_array = np.zeros([len(polygons), 4])
        for i, poly in enumerate(polygons):
            xs = [point[0] for point in poly]
            ys = [point[1] for point in poly]
            if xs:
                xy_array[i, 0] = max(xs)
                xy_array[i, 1] = min(xs)
            if ys:
                xy_array[i, 2] = max(ys)
                xy_array[i, 3] = min(ys)
        return xy_array

    def validate_polygon(self, polygon):
        """Function for testing and correcting the validity of polygons.

        Returns a shapely geometry, or None when the input is None or has
        fewer than three points. The result of make_valid is returned as is,
        so it is not necessarily a Polygon.
        """
        if polygon is None or len(polygon) <= 2:
            return None
        polygon = Polygon(polygon)
        if not polygon.is_valid:
            polygon = make_valid(polygon)
        return polygon

    def get_iou(self, poly1, poly2):
        """Function for calculating Intersection over Union (IoU) values.

        Returns 0 when either polygon cannot be validated, when the polygons
        do not intersect, or when their union has no area.
        """
        poly1 = self.validate_polygon(poly1)
        poly2 = self.validate_polygon(poly2)
        if not (poly1 and poly2) or not poly1.intersects(poly2):
            return 0.0
        union_area = poly1.union(poly2).area
        # Degenerate geometries can intersect while having no area at all
        if union_area <= 0:
            return 0.0
        return poly1.intersection(poly2).area / union_area

    @staticmethod
    def _should_merge(poly1, poly2, iou_threshold, overlap_threshold):
        """Tests whether two intersecting geometries fulfil the merge criteria."""
        intersect_area = poly1.intersection(poly2).area
        union_area = poly1.union(poly2).area
        iou = intersect_area / union_area if union_area > 0 else 0.0
        if iou > iou_threshold:
            return True
        if overlap_threshold:
            return intersect_area > (overlap_threshold * min(poly1.area, poly2.area))
        return False

    @staticmethod
    def _polygon_parts(geom):
        """Returns the Polygon members of a (possibly multi-part) geometry."""
        if geom.geom_type == 'Polygon':
            return [geom]
        if geom.geom_type in ('GeometryCollection', 'MultiPolygon'):
            return [part for part in geom.geoms if part.geom_type == 'Polygon']
        return []

    @staticmethod
    def _best_sources(sources, confs):
        """Picks, for each group of source indices, the index with the highest confidence."""
        return [max(group, key=lambda k: confs[k]) for group in sources]

    def _merge_with_sources(self, polygons, iou_threshold, overlap_threshold=None):
        """Merges polygons and records which inputs produced each output.

        Polygons are grouped into connected components of the pairwise merge
        criterion, and every component with more than one member is replaced
        by the union of its members, so no input polygon ends up in more than
        one output polygon.

        Returns a tuple (merged_polygons, sources): untouched input polygons
        come first in their original order, followed by the merged polygons.
        sources[k] lists the input indices that merged_polygons[k] is built from.
        """
        polygons = list(polygons)
        # Validate every polygon once instead of once per compared pair
        shapes = [self.validate_polygon(poly) for poly in polygons]
        n = len(shapes)
        # Disjoint-set forest; the root of a component is its smallest index
        parent = list(range(n))

        def find(k):
            while parent[k] != k:
                parent[k] = parent[parent[k]]
                k = parent[k]
            return k

        for i in range(n):
            poly1 = shapes[i]
            if not poly1:
                continue
            for j in range(i + 1, n):
                poly2 = shapes[j]
                if not poly2:
                    continue
                root_i, root_j = find(i), find(j)
                # Already in the same component, no geometry work needed
                if root_i == root_j:
                    continue
                if not poly1.intersects(poly2):
                    continue
                if self._should_merge(poly1, poly2, iou_threshold, overlap_threshold):
                    parent[max(root_i, root_j)] = min(root_i, root_j)

        groups = {}
        for k in range(n):
            groups.setdefault(find(k), []).append(k)

        dropped = set()
        new_polygons, new_sources = [], []
        for members in groups.values():
            if len(members) < 2:
                continue
            # Polygons that are merged together are dropped from the list
            dropped.update(members)
            merged = shapes[members[0]]
            for k in members[1:]:
                merged = merged.union(shapes[k])
            for part in self._polygon_parts(merged):
                new_polygons.append(list(part.exterior.coords))
                new_sources.append(list(members))

        kept = [k for k in range(n) if k not in dropped]
        merged_polygons = [polygons[k] for k in kept] + new_polygons
        sources = [[k] for k in kept] + new_sources
        return merged_polygons, sources

    def merge_polygons(self, polygons, iou_threshold, overlap_threshold=None):
        """Merges polygons that have an IoU value
        above the given threshold.

        When overlap_threshold is given, two polygons are also merged if their
        intersection covers more than that share of the smaller polygon.
        Returns the untouched input polygons followed by the merged polygons
        (as lists of exterior coordinates).
        """
        merged_polygons, _ = self._merge_with_sources(polygons, iou_threshold, overlap_threshold)
        return merged_polygons

    def _predict(self, model, img, conf, half, iou):
        """Runs a model on one image and returns the first result moved to CPU."""
        results = model.predict(source=img,
                                device=self.device,
                                conf=conf,
                                half=bool(half),
                                iou=iou)
        return results[0].cpu()

    def get_region_preds(self, img):
        """Function for predicting text region coordinates.

        Returns a list of region dicts (see get_region_ids), or None when no
        region model is available or nothing was detected.
        """
        if self.region_model is None:
            return None
        results = self._predict(self.region_model, img, self.region_conf_threshold,
                                self.region_half_precision, self.region_nms_iou)
        if not results.masks:
            return None
        # Predicted class labels and confidences of the raw detections
        classes = results.boxes.cls.tolist()
        box_confs = results.boxes.conf.tolist()
        # Merge overlapping polygons. Merging changes the number and order of
        # the polygons, so the source indices are needed to look up the class
        # and confidence of each remaining polygon.
        # NOTE(review): assumes masks.xy and boxes are index-aligned - confirm
        coords, sources = self._merge_with_sources(results.masks.xy, self.region_iou)
        # A merged polygon inherits the attributes of its most confident source
        best = self._best_sources(sources, box_confs)
        classes = [classes[k] for k in best]
        box_confs = [box_confs[k] for k in best]
        # Maximum and minimum x and y axis values used for ordering the polygons
        max_min = self.get_max_min(coords).tolist()
        # results.names maps class ids to class names and results.orig_shape
        # is the shape of the original image
        return self.get_region_ids(list(coords), max_min, classes, results.names,
                                   box_confs, results.orig_shape)

    def get_line_preds(self, img):
        """Function for predicting text line coordinates.

        Returns a dict with the equally long lists 'coords', 'max_min' and
        'confs', or None when no text lines were detected.
        """
        results = self._predict(self.line_model, img, self.line_conf_threshold,
                                self.line_half_precision, self.line_nms_iou)
        if not results.masks:
            return None
        raw_confs = results.boxes.conf.tolist()
        # Merge overlapping polygons, keeping track of the source detections
        # NOTE(review): assumes masks.xy and boxes are index-aligned - confirm
        coords, sources = self._merge_with_sources(results.masks.xy, self.line_iou,
                                                   self.line_overlap)
        # A merged line gets the highest confidence among its sources
        box_confs = [raw_confs[k] for k in self._best_sources(sources, raw_confs)]
        # Maximum and minimum x and y axis values for detected polygons
        max_min = self.get_max_min(coords).tolist()
        return {'coords': list(coords), 'max_min': max_min, 'confs': box_confs}

    def get_dist(self, line_polygon, regions):
        """Function for finding the closest region to the text line.

        Returns the id of the closest region, or None when the line polygon
        or all region polygons fail validation.
        """
        line_polygon = self.validate_polygon(line_polygon)
        if not line_polygon:
            return None
        dist, reg_id = float('inf'), None
        for region in regions:
            region_polygon = self.validate_polygon(region['coords'])
            if not region_polygon:
                continue
            # Distance between the line and region polygons
            line_reg_dist = line_polygon.distance(region_polygon)
            if line_reg_dist < dist:
                dist, reg_id = line_reg_dist, region['id']
        return reg_id

    def get_line_regions(self, lines, regions):
        """Function for connecting each text line to one region.

        Each line goes to the region with the highest IoU; lines that do not
        intersect any region go to the closest region instead. Returns a list
        of dicts with the keys 'polygon', 'reg_id', 'max_min' and 'conf'.
        """
        max_mins = lines['max_min']
        confs = lines['confs']
        lines_list = []
        for i, polygon in enumerate(lines['coords']):
            iou, reg_id = 0, ''
            for region in regions:
                line_reg_iou = self.get_iou(polygon, region['coords'])
                if line_reg_iou > iou:
                    iou, reg_id = line_reg_iou, region['id']
            # If line polygon does not intersect with any region, a distance
            # metric is used for defining the region that the line belongs to
            if iou == 0:
                reg_id = self.get_dist(polygon, regions)
            # Fall back to neutral values if the auxiliary lists are too short
            max_min = max_mins[i] if i < len(max_mins) else [0.0, 0.0, 0.0, 0.0]
            conf = confs[i] if i < len(confs) else 0.0
            lines_list.append({'polygon': polygon, 'reg_id': reg_id,
                               'max_min': max_min, 'conf': conf})
        return lines_list

    def order_regions_lines(self, lines, regions):
        """Function for ordering line predictions inside each region.

        Regions without any lines are left out. Returns a list of dicts with
        the keys 'region_coords', 'region_name', 'lines', 'line_confs',
        'region_conf' and 'img_shape'; the list is put in reading order when
        self.order_regions is set.
        """
        # Group the lines by region once instead of rescanning them per region
        lines_by_region = {}
        for line in lines:
            lines_by_region.setdefault(line['reg_id'], []).append(line)

        regions_with_rows = []
        region_max_mins = []
        for region in regions:
            region_lines = lines_by_region.get(region['id'])
            if not region_lines:
                continue
            # Line order inside the region is defined and the predicted text
            # lines are joined in the same python dict
            line_order = self.order_poly.order([line['max_min'] for line in region_lines])
            ordered = [region_lines[k] for k in line_order]
            regions_with_rows.append({'region_coords': region['coords'],
                                      'region_name': region['name'],
                                      'lines': [line['polygon'] for line in ordered],
                                      'line_confs': [line['conf'] for line in ordered],
                                      'region_conf': region['conf'],
                                      'img_shape': region['img_shape']})
            region_max_mins.append(region['max_min'])
        # Creates an ordering of the detected regions based on their polygon
        # coordinates; there is nothing to order when no region has lines
        if self.order_regions and regions_with_rows:
            region_order = self.order_poly.order(region_max_mins)
            regions_with_rows = [regions_with_rows[k] for k in region_order]
        return regions_with_rows

    def get_default_region(self, image):
        """Function for creating a default region if no regions are detected.

        The region covers the whole image; image.size is read as (width, height).
        """
        w, h = image.size
        region = {'coords': [[0.0, 0.0], [w, 0.0], [w, h], [0.0, h]],
                  'max_min': [w, 0.0, h, 0.0],
                  'class': '0',
                  'name': "paragraph",
                  'conf': 0.0,
                  'id': '0',
                  'img_shape': (h, w)}
        return [region]

    def get_segmentation(self, image):
        """Segment input image into ordered text lines or ordered text regions and text lines.

        Returns the output of order_regions_lines, or None when no text lines
        were detected.
        """
        line_preds = self.get_line_preds(image)
        if not line_preds:
            print(f'No text lines detected from image {image}')
            return None
        # Regions are detected only if a region model is available; otherwise,
        # or when nothing is detected, the whole image is used as one region
        region_preds = self.get_region_preds(image)
        if not region_preds:
            region_preds = self.get_default_region(image)
            if self.region_model is not None:
                print(f'No regions detected from image {image}')
        lines_with_regions = self.get_line_regions(line_preds, region_preds)
        return self.order_regions_lines(lines_with_regions, region_preds)