# Multicentury-HTR-Demo / segment_image.py
# MikkoLipsanen's picture
# Update segment_image.py
# 3351c8b verified
from huggingface_hub import hf_hub_download
from shapely.validation import make_valid
from shapely.geometry import Polygon
from ultralytics import YOLO
from PIL import Image
import numpy as np
import os
from reading_order import OrderPolygons
class SegmentImage:
    """Class for segmenting document image regions and text lines.

    A line detection model is always loaded; a region detection model is
    loaded only when ``region_model_path`` is given. Without a region model
    (or when it finds nothing) the whole image is used as one default region.
    """

    def __init__(self,
                 line_model_path,
                 device,
                 line_iou=0.5,
                 region_iou=0.5,
                 line_overlap=0.5,
                 line_nms_iou=0.7,
                 region_nms_iou=0.3,
                 line_conf_threshold=0.25,
                 region_conf_threshold=0.25,
                 region_model_path=None,
                 order_regions=True,
                 region_half_precision=False,
                 line_half_precision=False):
        """Store the settings and load the segmentation model(s).

        Args:
            line_model_path: Hugging Face repo id of the line detection model.
            device: Device used for prediction ('cpu', gpu '0', gpu '1' etc.).
            line_iou: IoU threshold for merging text line polygons.
            region_iou: IoU threshold for merging text region polygons.
            line_overlap: Share of the smaller line polygon that must be
                covered by the intersection for two lines to be merged.
            line_nms_iou: NMS IoU threshold passed to the line model.
            region_nms_iou: NMS IoU threshold passed to the region model.
            line_conf_threshold: Confidence threshold for line detection.
            region_conf_threshold: Confidence threshold for region detection.
            region_model_path: Hugging Face repo id of the region detection
                model, or None to skip region detection.
            order_regions: Whether a reading order is estimated for regions.
            region_half_precision: Whether the region model uses FP16.
            line_half_precision: Whether the line model uses FP16.
        """
        # Path to text line detection model
        self.line_model_path = line_model_path
        # Path to text region detection model
        self.region_model_path = region_model_path
        # Defines the IoU threshold used in the non-maximum suppression (NMS) process to
        # determine which prediction boxes should be suppressed or discarded based on
        # their overlap with other boxes
        self.line_nms_iou = line_nms_iou
        self.region_nms_iou = region_nms_iou
        # Defines the IoU threshold for text lines
        self.line_iou = line_iou
        # Defines the IoU threshold for text regions
        self.region_iou = region_iou
        # Defines the extent of line polygon overlap used for merging the polygons
        self.line_overlap = line_overlap
        # Defines confidence threshold for line detection
        self.line_conf_threshold = line_conf_threshold
        # Defines confidence threshold for region detection
        self.region_conf_threshold = region_conf_threshold
        # Defines the device to be used ('cpu', gpu '0', gpu '1' etc.)
        self.device = device
        # Defines whether a reading order is also estimated for the region detections
        self.order_regions = order_regions
        # Defines whether half precision (FP16) is used by the region and line prediction models
        self.region_half_precision = region_half_precision
        self.line_half_precision = line_half_precision
        self.order_poly = OrderPolygons()
        # Initialize segmentation model(s)
        self.line_model = self.init_line_model()
        # The attribute is always defined so that the rest of the class can
        # test for a missing region model instead of raising AttributeError
        self.region_model = None
        if self.region_model_path:
            self.region_model = self.init_region_model()

    def init_line_model(self):
        """Function for initializing the line detection model.

        Returns the loaded YOLO model, or None when loading fails (the
        error is printed, not raised).
        """
        try:
            # Load the trained line detection model
            cached_model_path = hf_hub_download(repo_id=self.line_model_path, filename="lines_20240827.pt")
            return YOLO(cached_model_path)
        except Exception as e:
            print('Failed to load the line detection model: %s' % e)
            return None

    def init_region_model(self):
        """Function for initializing the region detection model.

        Returns the loaded YOLO model, or None when loading fails (the
        error is printed, not raised).
        """
        try:
            # Load the trained region detection model
            cached_model_path = hf_hub_download(repo_id=self.region_model_path, filename="tuomiokirja_regions_04122023.pt")
            return YOLO(cached_model_path)
        except Exception as e:
            print('Failed to load the region detection model: %s' % e)
            return None

    def get_region_ids(self, coords, max_min, classes, names, box_confs, img_shape):
        """Function for creating unique id for each detected region.

        The input lists are expected to be index-aligned. Returns a list of
        dicts with the keys 'coords', 'max_min', 'class', 'name', 'conf',
        'id' and 'img_shape'.
        """
        n = min(len(classes), len(coords))
        res = []
        for i in range(n):
            res.append({'coords': coords[i],
                        'max_min': max_min[i],
                        'class': str(classes[i]),
                        # Region name corresponding to the class index
                        'name': names[classes[i]],
                        'conf': box_confs[i],
                        # Simple index-based id for each region
                        'id': str(i),
                        'img_shape': img_shape})
        return res

    def get_max_min(self, polygons):
        """Creates an array with the minimum and maximum
        x and y values of the input polygons.

        Each row is [max x, min x, max y, min y]; the row of an empty
        polygon stays all zeros.
        """
        xy_array = np.zeros([len(polygons), 4])
        for row, poly in enumerate(polygons):
            xs = [point[0] for point in poly]
            ys = [point[1] for point in poly]
            if xs:
                xy_array[row, 0] = max(xs)
                xy_array[row, 1] = min(xs)
            if ys:
                xy_array[row, 2] = max(ys)
                xy_array[row, 3] = min(ys)
        return xy_array

    def validate_polygon(self, polygon):
        """Function for testing and correcting the validity of polygons.

        Returns a shapely geometry, or None when the input has fewer than
        three points. A repaired geometry is not necessarily a Polygon.
        """
        if len(polygon) > 2:
            polygon = Polygon(polygon)
            if not polygon.is_valid:
                polygon = make_valid(polygon)
            return polygon
        return None

    def get_iou(self, poly1, poly2):
        """Function for calculating Intersection over Union (IoU) values.

        Returns 0 when either polygon is unusable, when the polygons do not
        intersect, or when their union has no area.
        """
        # If the polygons don't intersect, IoU is 0
        iou = 0
        poly1 = self.validate_polygon(poly1)
        poly2 = self.validate_polygon(poly2)
        if poly1 and poly2 and poly1.intersects(poly2):
            union_area = poly1.union(poly2).area
            # Degenerate (zero-area) geometries would otherwise divide by zero
            if union_area > 0:
                iou = poly1.intersection(poly2).area / union_area
        return iou

    @staticmethod
    def _should_merge(geom1, geom2, iou_threshold, overlap_threshold=None):
        """Tells whether two validated geometries overlap enough to be merged."""
        if not geom1.intersects(geom2):
            return False
        intersect_area = geom1.intersection(geom2).area
        union_area = geom1.union(geom2).area
        # Guard against degenerate geometries whose union has no area
        iou = intersect_area / union_area if union_area > 0 else 0.0
        if iou > iou_threshold:
            return True
        if overlap_threshold:
            return intersect_area > (overlap_threshold * min(geom1.area, geom2.area))
        return False

    @staticmethod
    def _exterior_rings(geom):
        """Returns the exterior coordinate lists of all polygons in a geometry."""
        if geom.geom_type in ['GeometryCollection', 'MultiPolygon']:
            return [list(part.exterior.coords) for part in geom.geoms
                    if part.geom_type == 'Polygon']
        if geom.geom_type == 'Polygon':
            return [list(geom.exterior.coords)]
        return []

    @staticmethod
    def _best_sources(sources, confs):
        """Picks, for each group of source indices, the most confident one."""
        return [max(group, key=lambda k: confs[k]) for group in sources]

    def _merge_polygons_tracked(self, polygons, iou_threshold, overlap_threshold=None):
        """Merges overlapping polygons and remembers where each result came from.

        Returns a tuple (merged_polygons, sources): sources[k] lists the
        indices of the input polygons that produced merged_polygons[k].
        Untouched polygons come first in their original order, followed by
        the merged ones.
        """
        polygons = list(polygons)
        n = len(polygons)
        # Each polygon is validated only once instead of once per comparison
        shapes = [self.validate_polygon(poly) for poly in polygons]

        # Union-find over polygon indices: polygons connected through
        # qualifying overlaps end up in one group, so every group is emitted
        # exactly once and no polygon is merged into two different results
        parent = list(range(n))

        def find(k):
            while parent[k] != k:
                parent[k] = parent[parent[k]]
                k = parent[k]
            return k

        for i in range(n):
            if not shapes[i]:
                continue
            for j in range(i + 1, n):
                if not shapes[j]:
                    continue
                if self._should_merge(shapes[i], shapes[j], iou_threshold, overlap_threshold):
                    root_i, root_j = find(i), find(j)
                    if root_i != root_j:
                        # The smallest index stays as the group root
                        parent[max(root_i, root_j)] = min(root_i, root_j)

        groups = {}
        for k in range(n):
            groups.setdefault(find(k), []).append(k)

        kept, kept_sources = [], []
        merged, merged_sources = [], []
        for root, members in groups.items():
            if len(members) == 1:
                kept.append(polygons[root])
                kept_sources.append(members)
                continue
            geom = shapes[members[0]]
            for k in members[1:]:
                geom = geom.union(shapes[k])
            # A union can consist of several polygons; all share the sources
            for ring in self._exterior_rings(geom):
                merged.append(ring)
                merged_sources.append(members)
        return kept + merged, kept_sources + merged_sources

    def merge_polygons(self, polygons, iou_threshold, overlap_threshold=None):
        """Merges polygons that have an IoU value
        above the given threshold.

        When overlap_threshold is given, two polygons are also merged if
        their intersection covers more than that share of the smaller one.
        Polygons that take part in a merge are dropped from the list and
        the merged polygons are appended to its end.
        """
        merged, _ = self._merge_polygons_tracked(polygons, iou_threshold, overlap_threshold)
        return merged

    def get_region_preds(self, img):
        """Function for predicting text region coordinates.

        Returns a list of region dicts (see get_region_ids), or None when no
        region model is available or nothing is detected.
        """
        if self.region_model is None:
            return None
        results = self.region_model.predict(source=img,
                                            device=self.device,
                                            conf=self.region_conf_threshold,
                                            half=bool(self.region_half_precision),
                                            iou=self.region_nms_iou)
        results = results[0].cpu()
        if not results.masks:
            return None
        # Extracts detected region polygons
        polygons = list(results.masks.xy)
        # Predicted class labels and confidence values of the detections
        raw_classes = results.boxes.cls.tolist()
        raw_confs = results.boxes.conf.tolist()
        n = min(len(polygons), len(raw_classes), len(raw_confs))
        # Merge overlapping polygons
        coords, sources = self._merge_polygons_tracked(polygons[:n], self.region_iou)
        # Merging reorders and removes polygons, so class and confidence are
        # taken from the most confident source detection of each polygon
        best = self._best_sources(sources, raw_confs)
        classes = [raw_classes[k] for k in best]
        box_confs = [raw_confs[k] for k in best]
        # Maximum and minimum x and y axis values for detected polygons used for ordering the polygons
        max_min = self.get_max_min(coords).tolist()
        # A dictionary with class ids as keys and class names as values
        names = results.names
        # A tuple containing the shape of the original image
        img_shape = results.orig_shape
        return self.get_region_ids(coords, max_min, classes, names, box_confs, img_shape)

    def get_line_preds(self, img):
        """Function for predicting text line coordinates.

        Returns a dict with the index-aligned lists 'coords', 'max_min' and
        'confs', or None when nothing is detected.
        """
        results = self.line_model.predict(source=img,
                                          device=self.device,
                                          conf=self.line_conf_threshold,
                                          half=bool(self.line_half_precision),
                                          iou=self.line_nms_iou)
        results = results[0].cpu()
        if not results.masks:
            return None
        # Detected text line polygons
        polygons = list(results.masks.xy)
        # Confidence values for detections
        raw_confs = results.boxes.conf.tolist()
        n = min(len(polygons), len(raw_confs))
        # Merge overlapping polygons
        coords, sources = self._merge_polygons_tracked(polygons[:n], self.line_iou, self.line_overlap)
        # Keeps the confidences aligned with the merged polygons
        box_confs = [raw_confs[k] for k in self._best_sources(sources, raw_confs)]
        # Maximum and minimum x and y axis values for detected polygons
        max_min = self.get_max_min(coords).tolist()
        return {'coords': coords, 'max_min': max_min, 'confs': box_confs}

    def get_dist(self, line_polygon, regions):
        """Function for finding the closest region to the text line.

        Returns the id of the closest region, or None when the line or all
        of the regions are unusable.
        """
        dist, reg_id = float('inf'), None
        line_polygon = self.validate_polygon(line_polygon)
        if line_polygon:
            for region in regions:
                # Calculates distance between line and region polygons
                region_polygon = self.validate_polygon(region['coords'])
                if region_polygon:
                    line_reg_dist = line_polygon.distance(region_polygon)
                    if line_reg_dist < dist:
                        dist = line_reg_dist
                        reg_id = region['id']
        return reg_id

    def get_line_regions(self, lines, regions):
        """Function for connecting each text line to one region.

        Returns one dict per line with the keys 'polygon', 'reg_id',
        'max_min' and 'conf'.
        """
        lines_list = []
        for i, polygon in enumerate(lines['coords']):
            iou, reg_id = 0, ''
            for region in regions:
                line_reg_iou = self.get_iou(polygon, region['coords'])
                if line_reg_iou > iou:
                    iou = line_reg_iou
                    reg_id = region['id']
            # If line polygon does not intersect with any region, a distance metric is used for defining
            # the region that the line belongs to
            if iou == 0:
                reg_id = self.get_dist(polygon, regions)
            # Defaults are used if the value lists are shorter than the polygon list
            max_min = lines['max_min'][i] if i < len(lines['max_min']) else [0.0, 0.0, 0.0, 0.0]
            conf = lines['confs'][i] if i < len(lines['confs']) else 0.0
            lines_list.append({'polygon': polygon, 'reg_id': reg_id, 'max_min': max_min, 'conf': conf})
        return lines_list

    def order_regions_lines(self, lines, regions):
        """Function for ordering line predictions inside each region.

        Regions without any lines are left out. Returns a list of dicts with
        the keys 'region_coords', 'region_name', 'lines', 'line_confs',
        'region_conf' and 'img_shape'.
        """
        # Groups the lines by region id once instead of scanning all lines for every region
        lines_by_region = {}
        for line in lines:
            lines_by_region.setdefault(line['reg_id'], []).append(line)

        regions_with_rows = []
        region_max_mins = []
        for region in regions:
            region_lines = lines_by_region.get(region['id'])
            if not region_lines:
                continue
            # If one or more lines are connected to a region, line order inside the region is defined
            # and the predicted text lines are joined in the same python dict
            line_order = self.order_poly.order([line['max_min'] for line in region_lines])
            new_region = {'region_coords': region['coords'],
                          'region_name': region['name'],
                          'lines': [region_lines[k]['polygon'] for k in line_order],
                          'line_confs': [region_lines[k]['conf'] for k in line_order],
                          'region_conf': region['conf'],
                          'img_shape': region['img_shape']}
            region_max_mins.append(region['max_min'])
            regions_with_rows.append(new_region)
        # Creates an ordering of the detected regions based on their polygon coordinates
        if self.order_regions and regions_with_rows:
            region_order = self.order_poly.order(region_max_mins)
            regions_with_rows = [regions_with_rows[k] for k in region_order]
        return regions_with_rows

    def get_default_region(self, image):
        """Function for creating a default region if no regions are detected.

        The region covers the whole image; image.size is read as
        (width, height).
        """
        w, h = image.size
        region = {'coords': [[0.0, 0.0], [w, 0.0], [w, h], [0.0, h]],
                  'max_min': [w, 0.0, h, 0.0],
                  'class': '0',
                  'name': "paragraph",
                  'conf': 0.0,
                  'id': '0',
                  'img_shape': (h, w)}
        return [region]

    def get_segmentation(self, image):
        """Segment input image into ordered text lines or ordered text regions and text lines.

        Returns the list produced by order_regions_lines, or None when no
        text lines are detected.
        """
        line_preds = self.get_line_preds(image)
        if not line_preds:
            print(f'No text lines detected from image {image}')
            return None
        region_preds = None
        # If region detection model is defined, text regions and text lines are detected
        if self.region_model is not None:
            region_preds = self.get_region_preds(image)
            if not region_preds:
                print(f'No regions detected from image {image}')
        # Without detected regions the whole image is treated as one region
        if not region_preds:
            region_preds = self.get_default_region(image)
        lines_with_regions = self.get_line_regions(line_preds, region_preds)
        return self.order_regions_lines(lines_with_regions, region_preds)