Commit 5652328b authored by wux's avatar wux

fix:add new ocr_metric algorithm

parent c01b4948
......@@ -13,4 +13,19 @@ res/ffmpeg-4.3.1/bin/output.mp4
res/ffmpeg-4.3.1/bin/qiji_local.mp4
venv/
venv37/
shenming_test
\ No newline at end of file
shenming_test
cap.png
requirements3.8.txt
venv3.8-new/
webrtcvad-2.0.10-cp38-abi3-win_amd64.whl
xlsx-resource/
deal_ocr.csv
deal_srt.csv
new.srt
shenhai1.xlsx
shenhai2.xlsx
test,py
"\346\267\261\346\265\267\347\237\255\347\211\2072.xlsx"
"\346\267\261\346\265\267\347\237\255\347\211\207origin.xlsx"
\ No newline at end of file
import re
import csv
import jieba
import argparse
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from difflib import SequenceMatcher
title = ['起始时间(转换后)', '终止时间(转换后)', '字幕']
# title = ['起始时间(转换后)', '终止时间(转换后)', '字幕']
title = ['起始时间', '终止时间', '字幕']
def init():
# 获取中文停用词列表
......@@ -22,13 +25,25 @@ def change_to_second(time_str):
time_obj.second + time_obj.microsecond / 1000000
return seconds
# 将中文句子划分,并且防止划分全部为停用词
def words_segment(str):
tmp = ','.join(jieba.cut(str))
# 将分割的句子差分成单词,也不进行划分
if is_all_stopwords(tmp) or len(list(jieba.cut(str))) == len(str) :
return str
return tmp
# 计算字幕的相似度
def calculate_similarity(str1, str2, method='cosine'):
if method == 'cosine':
tfidf_vectorizer = TfidfVectorizer()
tfidf_matrix = tfidf_vectorizer.fit_transform([str1, str2])
str1, str2 = words_segment(str1), words_segment(str2)
tfidf_vectorizer = TfidfVectorizer(min_df=1)
tfidf_matrix = tfidf_vectorizer.fit_transform([str1, str2]) # shape=[2, N]
# print(np.array(tfidf_matrix.toarray()).shape, type(tfidf_matrix), tfidf_matrix.toarray())
similarity_matrix = cosine_similarity(tfidf_matrix)
return similarity_matrix[0][1]
elif method == 'distance':
return -String_edit_distance(str1, str2)
else :
return SequenceMatcher(None, str1, str2).ratio()
......@@ -37,15 +52,37 @@ def calculate_time_difference(time1, time2):
return abs(time2 - time1)
def calculate_weight(x, y):
# weight = e^(-alpha * time_diff)
# 相差1s的系数为0.9
alpha = 0.11
return 1 / (alpha * (x + y) + 1)
# # weight = e^(-alpha * time_diff)
# # 相差1s的系数为0.9
# alpha = 0.11
# return 1 / (alpha * (x + y) + 1)
return 1.0 # 目前不考虑时间系数
# 检查句子中的每个单词是否都是停用词
def is_all_stopwords(sentence):
sentence = sentence.replace(' ', '')
return all(word in stop_words for word in sentence)
# 编辑距离算法 有问题!!!!!!
def String_edit_distance(str1, str2):
n, m = len(str1), len(str2)
dp = [[0 for _ in range(m+1)] for _ in range(n+1)]
for i in range(n+1):
dp[i][0] = i
for i in range(m+1):
dp[0][i] = i
dp[0][0] = 0
for i in range(1, n+1):
for j in range(1, m+1):
if str1[i-1] == str2[j-1]:
dp[i][j] = dp[i-1][j-1]
else :
dp[i][j] = min(dp[i-1][j-1], min(dp[i][j-1], dp[i-1][j])) + 1
# print(dp[n][m], n, m)
return 1.0 * dp[n][m] / max(n, m)
### 如果其中有-符号,可能在用excel打开时自动添加=变成公式,读取的时候没问题
def read_srt_to_csv(path_srt, path_output):
with open(path_srt, 'r', encoding='utf-8-sig') as f:
......@@ -70,20 +107,19 @@ def read_from_xlsx(path_xlsx='output.xlsx', path_output='deal.csv'):
csv_writer.writerow(title)
for _, data1 in data.iterrows():
start, end, subtitle = data1[1], data1[3], data1[4]
# print(data1[1])
start, end, subtitle = data1[0], data1[1], data1[2]
if isinstance(subtitle, float) and np.isnan(subtitle):
continue
# 与srt文件格式同步
start = start.replace('.', ',')
end = end.replace('.', ',')
# print(start, end, subtitle,)
# print(type(start), type(end), type(subtitle))
csv_writer.writerow([start, end, subtitle.strip()])
### 对于srt中的字幕计算相似性度。从ocr中找到时间戳满足<=time_t的字幕,
### 然后计算字幕间的相似度,取一个最大的。字幕从start和end都匹配一遍
# time_threshold设置阈值,用于判断时间差是否可接受
def measure_score(path_srt, path_ocr, time_threshold=5.0, method='cosine'):
def measure_score(path_srt, path_ocr, time_threshold=5.0, time_threshold_re=False, method='cosine'):
data_srt, data_ocr = [], []
with open(path_srt, 'r', encoding='utf-8') as file:
csv_reader = csv.reader(file)
......@@ -105,20 +141,29 @@ def measure_score(path_srt, path_ocr, time_threshold=5.0, method='cosine'):
total_weight = 0.0
for sub in data_srt:
max_similarity = 0.0
max_similarity = 0.0 if method != 'distance' else -1000.0
# 去除srt中的停用词
if is_all_stopwords(sub[2]):
continue
# subb = ""
for sub1 in data_ocr:
x, y = abs(sub[0] - sub1[0]), abs(sub[1] - sub1[1])
if min(x, y) <= time_threshold:
# print(sub[2], sub1[2])
score = calculate_similarity(sub[2], sub1[2], 'cosine')
if time_threshold_re:
time_threshold_tmp = time_threshold
else :
time_threshold_tmp = (sub[1] - sub[0]) * 0.3 # 10s允许3s的误差
if min(x, y) <= time_threshold_tmp:
score = calculate_similarity(sub[2], sub1[2], method)
# if max_similarity <= score * calculate_weight(x, y):
# subb = sub1[2]
max_similarity = max(max_similarity, score * calculate_weight(x, y))
# if max_similarity <= 0.5:
# print(max_similarity, sub[2], subb, sub[0])
total_similarity += max_similarity
total_weight += 1
# print(total_similarity, total_similarity / len(data_srt), total_similarity / total_weight)
if method == 'distance':
total_similarity = total_weight + total_similarity
return total_similarity / len(data_srt), total_similarity / total_weight
if __name__ == '__main__':
......@@ -128,13 +173,19 @@ if __name__ == '__main__':
# 添加命令行参数
parser.add_argument("--path_srt", required=True, type=str, help="Path of srt file, format is srt")
parser.add_argument("--path_ocr", required=True, type=str, help="Path of ocr file, format is xlsx")
parser.add_argument("--method", type=str, default='cosine', help="Select evaluation method")
parser.add_argument("--time_threshold", type=float,default=5.0, help="Allowable time frame")
parser.add_argument("--time_threshold", type=float, default=5.0, help="Allowable time frame")
parser.add_argument("--method", type=str, default='distance',choices=['cosine', 'distance', 'sequence']
, help="Select evaluation method")
parser.add_argument("--time_threshold_re", type=bool, default=True, help="Specify whether \
time threshold is required")
args = parser.parse_args()
output_file_srt = 'deal_srt.csv'
output_file_ocr = 'deal_ocr.csv'
read_srt_to_csv(args.path_srt, output_file_srt)
read_from_xlsx(args.path_ocr, output_file_ocr)
score = measure_score(output_file_srt, output_file_ocr, args.time_threshold, args.method)
score = measure_score(output_file_srt, output_file_ocr, args.time_threshold, \
args.time_threshold_re, args.method)
print(f'该评估算法得分: {score[1]:.5f}')
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment