Commit 5652328b authored by wux's avatar wux

fix: add new ocr_metric algorithm

parent c01b4948
...@@ -13,4 +13,19 @@ res/ffmpeg-4.3.1/bin/output.mp4 ...@@ -13,4 +13,19 @@ res/ffmpeg-4.3.1/bin/output.mp4
res/ffmpeg-4.3.1/bin/qiji_local.mp4 res/ffmpeg-4.3.1/bin/qiji_local.mp4
venv/ venv/
venv37/ venv37/
shenming_test shenming_test
\ No newline at end of file
cap.png
requirements3.8.txt
venv3.8-new/
webrtcvad-2.0.10-cp38-abi3-win_amd64.whl
xlsx-resource/
deal_ocr.csv
deal_srt.csv
new.srt
shenhai1.xlsx
shenhai2.xlsx
test,py
"\346\267\261\346\265\267\347\237\255\347\211\2072.xlsx"
"\346\267\261\346\265\267\347\237\255\347\211\207origin.xlsx"
\ No newline at end of file
import re import re
import csv import csv
import jieba
import argparse import argparse
import pandas as pd import pandas as pd
import numpy as np import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity from sklearn.metrics.pairwise import cosine_similarity
from difflib import SequenceMatcher from difflib import SequenceMatcher
title = ['起始时间(转换后)', '终止时间(转换后)', '字幕'] # title = ['起始时间(转换后)', '终止时间(转换后)', '字幕']
title = ['起始时间', '终止时间', '字幕']
def init(): def init():
# 获取中文停用词列表 # 获取中文停用词列表
...@@ -22,13 +25,25 @@ def change_to_second(time_str): ...@@ -22,13 +25,25 @@ def change_to_second(time_str):
time_obj.second + time_obj.microsecond / 1000000 time_obj.second + time_obj.microsecond / 1000000
return seconds return seconds
# Segment a Chinese sentence into words, guarding against a result that
# consists only of stop words.
def words_segment(str):
    """Segment ``str`` with jieba and return the tokens joined by commas.

    The original, unsegmented string is returned instead when either

    * ``is_all_stopwords`` reports True for the comma-joined form, or
    * jieba yields as many tokens as the input has characters (for a
      segmentation that partitions the input this means every token is a
      single character, so segmenting adds nothing).

    Args:
        str: Sentence to segment. The name shadows the builtin; it is kept
            unchanged so existing callers are unaffected.

    Returns:
        The comma-joined tokens, or ``str`` unchanged in the two cases above.
    """
    # Tokenize once; the same token list feeds both the join and the
    # length comparison (the sentence was previously segmented twice).
    tokens = list(jieba.cut(str))
    joined = ','.join(tokens)
    if is_all_stopwords(joined) or len(tokens) == len(str):
        return str
    return joined
# Compute the similarity between two subtitle strings.
def calculate_similarity(str1, str2, method='cosine'):
    """Return a similarity score for two subtitle strings.

    Args:
        str1: First subtitle text.
        str2: Second subtitle text.
        method: Scoring method.
            ``'cosine'``   - both strings are passed through
            ``words_segment`` and compared by cosine similarity of their
            TF-IDF vectors.
            ``'distance'`` - the negated result of ``String_edit_distance``,
            so that a larger value still means "more similar".
            anything else  - ``difflib.SequenceMatcher(...).ratio()``.

    Returns:
        The score as a float (a NumPy float for ``'cosine'``).
    """
    if method == 'cosine':
        str1, str2 = words_segment(str1), words_segment(str2)
        tfidf_vectorizer = TfidfVectorizer(min_df=1)
        try:
            tfidf_matrix = tfidf_vectorizer.fit_transform([str1, str2])  # shape=[2, N]
        except ValueError:
            # scikit-learn raises ValueError ("empty vocabulary") when
            # neither string produces a token, e.g. two single-character
            # subtitles. Score that case directly instead of aborting the
            # whole evaluation run.
            return 1.0 if str1 == str2 else 0.0
        similarity_matrix = cosine_similarity(tfidf_matrix)
        return similarity_matrix[0][1]
    if method == 'distance':
        return -String_edit_distance(str1, str2)
    return SequenceMatcher(None, str1, str2).ratio()
...@@ -37,15 +52,37 @@ def calculate_time_difference(time1, time2): ...@@ -37,15 +52,37 @@ def calculate_time_difference(time1, time2):
return abs(time2 - time1) return abs(time2 - time1)
def calculate_weight(x, y):
    """Return the time-based weight applied to a similarity score.

    Time weighting is currently switched off, so the result is always 1.0
    and both arguments are ignored.

    The disabled formula, kept here for reference, was
    ``1 / (alpha * (x + y) + 1)`` with ``alpha = 0.11`` (a coefficient of
    roughly 0.9 for a one-second difference).

    Args:
        x: First time difference (unused).
        y: Second time difference (unused).

    Returns:
        The constant ``1.0``.
    """
    weight = 1.0  # time coefficient is not taken into account for now
    return weight
# Check whether everything in the sentence is a stop word.
def is_all_stopwords(sentence):
    """Return True when every character of ``sentence`` is in ``stop_words``.

    Spaces are removed first. Iteration is over the characters of the
    string, so membership is tested one character at a time. An empty
    sentence (or one made only of spaces) yields True.

    Args:
        sentence: Text to inspect.

    Returns:
        bool: False as soon as one character is not in the module-level
        ``stop_words`` collection, True otherwise.
    """
    compact = sentence.replace(' ', '')
    for char in compact:
        if char not in stop_words:
            return False
    return True
# Edit distance (Levenshtein), normalized by the length of the longer string.
def String_edit_distance(str1, str2):
    """Return the normalized edit distance between ``str1`` and ``str2``.

    The distance counts single-character insertions, deletions and
    substitutions, and is divided by ``max(len(str1), len(str2))``, so the
    result lies in ``[0.0, 1.0]`` (0.0 for identical strings).

    Args:
        str1: First string.
        str2: Second string.

    Returns:
        float: Normalized distance. Two empty strings give ``0.0``; this
        case previously raised ``ZeroDivisionError``.
    """
    n, m = len(str1), len(str2)
    longest = max(n, m)
    if longest == 0:
        # Both inputs are empty: nothing to edit, and nothing to divide by.
        return 0.0

    # Only the previous DP row is needed to compute the current one.
    # previous[j] is the distance between str1[:i-1] and str2[:j].
    previous = list(range(m + 1))
    for i in range(1, n + 1):
        current = [i] + [0] * m
        for j in range(1, m + 1):
            if str1[i - 1] == str2[j - 1]:
                current[j] = previous[j - 1]
            else:
                # substitution, insertion, deletion
                current[j] = min(previous[j - 1], current[j - 1], previous[j]) + 1
        previous = current
    return 1.0 * previous[m] / longest
### 如果其中有-符号,可能在用excel打开时自动添加=变成公式,读取的时候没问题 ### 如果其中有-符号,可能在用excel打开时自动添加=变成公式,读取的时候没问题
def read_srt_to_csv(path_srt, path_output): def read_srt_to_csv(path_srt, path_output):
with open(path_srt, 'r', encoding='utf-8-sig') as f: with open(path_srt, 'r', encoding='utf-8-sig') as f:
...@@ -70,20 +107,19 @@ def read_from_xlsx(path_xlsx='output.xlsx', path_output='deal.csv'): ...@@ -70,20 +107,19 @@ def read_from_xlsx(path_xlsx='output.xlsx', path_output='deal.csv'):
csv_writer.writerow(title) csv_writer.writerow(title)
for _, data1 in data.iterrows(): for _, data1 in data.iterrows():
start, end, subtitle = data1[1], data1[3], data1[4] # print(data1[1])
start, end, subtitle = data1[0], data1[1], data1[2]
if isinstance(subtitle, float) and np.isnan(subtitle): if isinstance(subtitle, float) and np.isnan(subtitle):
continue continue
# 与srt文件格式同步 # 与srt文件格式同步
start = start.replace('.', ',') start = start.replace('.', ',')
end = end.replace('.', ',') end = end.replace('.', ',')
# print(start, end, subtitle,)
# print(type(start), type(end), type(subtitle))
csv_writer.writerow([start, end, subtitle.strip()]) csv_writer.writerow([start, end, subtitle.strip()])
### 对于srt中的字幕计算相似性度。从ocr中找到时间戳满足<=time_t的字幕, ### 对于srt中的字幕计算相似性度。从ocr中找到时间戳满足<=time_t的字幕,
### 然后计算字幕间的相似度,取一个最大的。字幕从start和end都匹配一遍 ### 然后计算字幕间的相似度,取一个最大的。字幕从start和end都匹配一遍
# time_threshold设置阈值,用于判断时间差是否可接受 # time_threshold设置阈值,用于判断时间差是否可接受
def measure_score(path_srt, path_ocr, time_threshold=5.0, method='cosine'): def measure_score(path_srt, path_ocr, time_threshold=5.0, time_threshold_re=False, method='cosine'):
data_srt, data_ocr = [], [] data_srt, data_ocr = [], []
with open(path_srt, 'r', encoding='utf-8') as file: with open(path_srt, 'r', encoding='utf-8') as file:
csv_reader = csv.reader(file) csv_reader = csv.reader(file)
...@@ -105,20 +141,29 @@ def measure_score(path_srt, path_ocr, time_threshold=5.0, method='cosine'): ...@@ -105,20 +141,29 @@ def measure_score(path_srt, path_ocr, time_threshold=5.0, method='cosine'):
total_weight = 0.0 total_weight = 0.0
for sub in data_srt: for sub in data_srt:
max_similarity = 0.0 max_similarity = 0.0 if method != 'distance' else -1000.0
# 去除srt中的停用词 # 去除srt中的停用词
if is_all_stopwords(sub[2]): if is_all_stopwords(sub[2]):
continue continue
# subb = ""
for sub1 in data_ocr: for sub1 in data_ocr:
x, y = abs(sub[0] - sub1[0]), abs(sub[1] - sub1[1]) x, y = abs(sub[0] - sub1[0]), abs(sub[1] - sub1[1])
if min(x, y) <= time_threshold: if time_threshold_re:
# print(sub[2], sub1[2]) time_threshold_tmp = time_threshold
score = calculate_similarity(sub[2], sub1[2], 'cosine') else :
time_threshold_tmp = (sub[1] - sub[0]) * 0.3 # 10s允许3s的误差
if min(x, y) <= time_threshold_tmp:
score = calculate_similarity(sub[2], sub1[2], method)
# if max_similarity <= score * calculate_weight(x, y):
# subb = sub1[2]
max_similarity = max(max_similarity, score * calculate_weight(x, y)) max_similarity = max(max_similarity, score * calculate_weight(x, y))
# if max_similarity <= 0.5:
# print(max_similarity, sub[2], subb, sub[0])
total_similarity += max_similarity total_similarity += max_similarity
total_weight += 1 total_weight += 1
if method == 'distance':
# print(total_similarity, total_similarity / len(data_srt), total_similarity / total_weight) total_similarity = total_weight + total_similarity
return total_similarity / len(data_srt), total_similarity / total_weight return total_similarity / len(data_srt), total_similarity / total_weight
if __name__ == '__main__': if __name__ == '__main__':
...@@ -128,13 +173,19 @@ if __name__ == '__main__': ...@@ -128,13 +173,19 @@ if __name__ == '__main__':
# 添加命令行参数 # 添加命令行参数
parser.add_argument("--path_srt", required=True, type=str, help="Path of srt file, format is srt") parser.add_argument("--path_srt", required=True, type=str, help="Path of srt file, format is srt")
parser.add_argument("--path_ocr", required=True, type=str, help="Path of ocr file, format is xlsx") parser.add_argument("--path_ocr", required=True, type=str, help="Path of ocr file, format is xlsx")
parser.add_argument("--method", type=str, default='cosine', help="Select evaluation method") parser.add_argument("--time_threshold", type=float, default=5.0, help="Allowable time frame")
parser.add_argument("--time_threshold", type=float,default=5.0, help="Allowable time frame")
parser.add_argument("--method", type=str, default='distance',choices=['cosine', 'distance', 'sequence']
, help="Select evaluation method")
parser.add_argument("--time_threshold_re", type=bool, default=True, help="Specify whether \
time threshold is required")
args = parser.parse_args() args = parser.parse_args()
output_file_srt = 'deal_srt.csv' output_file_srt = 'deal_srt.csv'
output_file_ocr = 'deal_ocr.csv' output_file_ocr = 'deal_ocr.csv'
read_srt_to_csv(args.path_srt, output_file_srt) read_srt_to_csv(args.path_srt, output_file_srt)
read_from_xlsx(args.path_ocr, output_file_ocr) read_from_xlsx(args.path_ocr, output_file_ocr)
score = measure_score(output_file_srt, output_file_ocr, args.time_threshold, args.method) score = measure_score(output_file_srt, output_file_ocr, args.time_threshold, \
args.time_threshold_re, args.method)
print(f'该评估算法得分: {score[1]:.5f}') print(f'该评估算法得分: {score[1]:.5f}')
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment