Commit ec38f0a0 authored by wux51's avatar wux51

fix:origin ocr algorithm issue

parent 5652328b
......@@ -28,4 +28,16 @@ shenhai1.xlsx
shenhai2.xlsx
test,py
"\346\267\261\346\265\267\347\237\255\347\211\2072.xlsx"
"\346\267\261\346\265\267\347\237\255\347\211\207origin.xlsx"
\ No newline at end of file
"\346\267\261\346\265\267\347\237\255\347\211\207origin.xlsx"
11.py
222.py
cap/
cap1597.png
cap831.png
deal.py
deal_movie.py
movie_1.txt
movie_pro.txt
res/.paddleocr/2.3.0.1/ocr/paddleocr/
script1.py
test/
\ No newline at end of file
......@@ -45,6 +45,11 @@ cur_det_model_dir = paddle_dir + "det/ch/ch_PP-OCRv2_det_infer"
cur_rec_model_dir = paddle_dir + "rec/ch/ch_PP-OCRv2_rec_infer"
ocr = PaddleOCR(use_angle_cls=True, lang="ch", show_log=False, use_gpu=False, cls_model_dir=cur_cls_model_dir, det_model_dir=cur_det_model_dir, rec_model_dir=cur_rec_model_dir)
# paddle_dir = "res/.paddleocr/2.3.0.1/ocr/paddleocr/"
# cur_det_model_dir = paddle_dir + "ch_PP-OCRv4_det_infer"
# cur_rec_model_dir = paddle_dir + "ch_PP-OCRv4_rec_infer"
# ocr = PaddleOCR(use_angle_cls=True, lang="ch", show_log=False, use_gpu=False, det_model_dir=cur_det_model_dir, rec_model_dir=cur_rec_model_dir)
# 正常语速为4字/秒
normal_speed = 4
......@@ -251,6 +256,16 @@ def normalize(text: str) -> str:
text = text + ')'
return text
def extract_white_prior(img, threshold=200):
gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# 设定阈值,将非白色部分二值化为黑色
ret, binary_image = cv2.threshold(gray, threshold, 255, cv2.THRESH_BINARY)
return binary_image
index = 0
def detect_subtitle(org_img: np.ndarray) -> Tuple[Union[str, None], float]:
"""检测当前画面得到字幕信息
......@@ -298,6 +313,9 @@ def detect_subtitle(org_img: np.ndarray) -> Tuple[Union[str, None], float]:
if img.shape[1] < 1000:
img = cv2.resize(img, (int(img.shape[1] * 1.5), int(img.shape[0] * 1.5)))
global index
# img = extract_white_prior(img)
cv2.imwrite(f'./cap/cap{index}.png', img)
index = index + 1
print(">>>>>>>>>>>>>>>>>>>>>>>>>>>new log" + str(index - 1))
......@@ -311,6 +329,7 @@ def detect_subtitle(org_img: np.ndarray) -> Tuple[Union[str, None], float]:
possible_txt = []
conf = 0
print('res --------->', res)
res.sort(key=lambda rect: rect[0][0][0] + rect[0][1][0]) # 按照中心点排序
for x in res:
# cv2.imshow("cut", img)
# cv2.waitKey(0)
......@@ -400,6 +419,7 @@ def process_video(video_path: str, begin: float, end: float, book_path: str, she
cur_time = video.get(cv2.CAP_PROP_POS_MSEC) / 1000
# 判断当前帧是否已超限制
# end 主要用来判断是否越界
if cur_time > end:
if cur_time - end_time > 1:
print('--------------------------------------------------')
......@@ -606,7 +626,7 @@ def add_to_list(mainWindow: MainWindow, element_type: str, li: list, ocr_h : int
mainWindow.last_aside_index = len(mainWindow.projectContext.all_elements) - 1
# end_time 主要用来判断是否越界
def detect_with_ocr(video_path: str, book_path: str, start_time: float, end_time: float, state=None, mainWindow: MainWindow=None):
"""使用ocr检测视频获取字幕并输出旁白推荐
......
import re
import sys
import csv
import jieba
import argparse
......@@ -7,6 +8,7 @@ import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from difflib import SequenceMatcher
from tqdm import tqdm
# title = ['起始时间(转换后)', '终止时间(转换后)', '字幕']
title = ['起始时间', '终止时间', '字幕']
......@@ -33,6 +35,13 @@ def words_segment(str):
return str
return tmp
# 判断是否从中英文字幕中提取中文
def extract_info(str, has_english=False):
if not has_english:
return str
chinese_text = re.findall(r'[\u4e00-\u9fff]+', str)
return ' '.join(chinese_text)
# 计算字幕的相似度
def calculate_similarity(str1, str2, method='cosine'):
if method == 'cosine':
......@@ -85,12 +94,29 @@ def String_edit_distance(str1, str2):
### 如果其中有-符号,可能在用excel打开时自动添加=变成公式,读取的时候没问题
def read_srt_to_csv(path_srt, path_output):
with open(path_srt, 'r', encoding='utf-8-sig') as f:
srt_content = f.read() # str
try:
with open(path_srt, 'r', encoding='utf-8-sig') as f:
srt_content = f.read() # str
except UnicodeDecodeError:
print(f"编码错误,已经切换到utf-16编码")
try:
with open(path_srt, 'r', encoding='utf-16') as f:
srt_content = f.read() # str
except:
print(f"请选择utf-8或utf-16编码形式的srt文件")
sys.exit(1)
# 使用正则表达式匹配时间码和字幕内容
pattern = re.compile(r'(\d+)\n([\d:,]+) --> ([\d:,]+)\n(.+?)(?=\n\d+\n|$)', re.DOTALL)
matches = pattern.findall(srt_content)
has_english = []
for i in range(5):
idx = np.random.randint(len(matches))
pattern = re.compile(r'[a-zA-Z]')
has_english.append(bool(pattern.search(matches[idx][3])))
has_english = all(has_english)
print('!'*20, has_english)
# 写入 csv 文件
with open(path_output, 'w', newline='', encoding='utf-8') as f:
csv_writer = csv.writer(f)
......@@ -98,7 +124,7 @@ def read_srt_to_csv(path_srt, path_output):
for _, start, end, subtitle in matches: # 都是str格式
subtitle = re.sub(r'\{[^}]*\}', '', subtitle) # 将srt文件前的加粗等格式去掉
csv_writer.writerow([start, end, subtitle.strip()])
csv_writer.writerow([start, end, extract_info(subtitle.strip(), has_english)])
def read_from_xlsx(path_xlsx='output.xlsx', path_output='deal.csv'):
data = pd.read_excel(path_xlsx)
......@@ -139,13 +165,13 @@ def measure_score(path_srt, path_ocr, time_threshold=5.0, time_threshold_re=Fals
# 计算相似度
total_similarity = 0.0
total_weight = 0.0
for sub in data_srt:
max_similarity = 0.0 if method != 'distance' else -1000.0
txt1 = []
for sub in tqdm(data_srt, desc="Processing", ncols=100):
max_similarity = 0.0 if method != 'distance' else -1.0
# 去除srt中的停用词
if is_all_stopwords(sub[2]):
continue
# subb = ""
subb = ""
for sub1 in data_ocr:
x, y = abs(sub[0] - sub1[0]), abs(sub[1] - sub1[1])
if time_threshold_re:
......@@ -154,15 +180,20 @@ def measure_score(path_srt, path_ocr, time_threshold=5.0, time_threshold_re=Fals
time_threshold_tmp = (sub[1] - sub[0]) * 0.3 # 10s允许3s的误差
if min(x, y) <= time_threshold_tmp:
score = calculate_similarity(sub[2], sub1[2], method)
# if max_similarity <= score * calculate_weight(x, y):
# subb = sub1[2]
if max_similarity <= score * calculate_weight(x, y):
subb = sub1[2]
max_similarity = max(max_similarity, score * calculate_weight(x, y))
# if max_similarity <= 0.5:
# print(max_similarity, sub[2], subb, sub[0])
if max_similarity <= -0.5:
# print(max_similarity, sub[2], subb, sub[0])
txt1.append(' !!! '.join([str(max_similarity), sub[2], subb, str(sub[0])]))
total_similarity += max_similarity
total_weight += 1
if method == 'distance':
total_similarity = total_weight + total_similarity
with open('movie_pro.txt', 'w', encoding='utf-8') as f:
for i in txt1:
f.write(i + '\n')
return total_similarity / len(data_srt), total_similarity / total_weight
......@@ -188,4 +219,8 @@ if __name__ == '__main__':
read_from_xlsx(args.path_ocr, output_file_ocr)
score = measure_score(output_file_srt, output_file_ocr, args.time_threshold, \
args.time_threshold_re, args.method)
print(f'该评估算法得分: {score[1]:.5f}')
\ No newline at end of file
print(f'该评估算法得分: {100 * score[1]:.3f}')
# python ocr_metric.py --path_srt test/new/movie_1.srt --path_ocr ../测试/the-swan-v3/The.Swan-zimu.xlsx --time_threshold 10
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment