Commit 89e52096 authored by wux's avatar wux

add ocr_metric algorithm

parent b26cb23d
...@@ -9,3 +9,7 @@ dist ...@@ -9,3 +9,7 @@ dist
build build
log log
product_test product_test
res/ffmpeg-4.3.1/bin/output.mp4
res/ffmpeg-4.3.1/bin/qiji_local.mp4
venv/
venv37/
\ No newline at end of file
This diff is collapsed.
import re
import csv
import argparse
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from difflib import SequenceMatcher
# Shared CSV header row: start time (converted), end time (converted), subtitle text.
title = ['起始时间(转换后)', '终止时间(转换后)', '字幕']
def init():
    """Load the Chinese stop-word list into the module-level ``stop_words`` set.

    Reads one stop word per line from ``chinese_stopwords.txt`` in the
    working directory.
    """
    global stop_words
    with open('chinese_stopwords.txt', 'r', encoding='utf-8') as handle:
        stop_words = {entry.strip() for entry in handle}
def change_to_second(time_str):
    """Convert a subtitle timestamp ``HH:MM:SS,mmm`` into seconds as a float."""
    from datetime import datetime
    moment = datetime.strptime(time_str, "%H:%M:%S,%f")
    whole = moment.hour * 3600 + moment.minute * 60 + moment.second
    return whole + moment.microsecond / 1000000
# Similarity between two subtitle strings.
def calculate_similarity(str1, str2, method='cosine'):
    """Return a similarity score in [0, 1] for two subtitle strings.

    method == 'cosine' uses TF-IDF cosine similarity; any other value
    falls back to difflib.SequenceMatcher's ratio.
    """
    if method == 'cosine':
        try:
            tfidf_vectorizer = TfidfVectorizer()
            tfidf_matrix = tfidf_vectorizer.fit_transform([str1, str2])
            similarity_matrix = cosine_similarity(tfidf_matrix)
            return similarity_matrix[0][1]
        except ValueError:
            # FIX: TfidfVectorizer raises "empty vocabulary" when neither
            # string yields a token under the default token pattern
            # (e.g. punctuation-only or empty text); treat as no match
            # instead of crashing the whole scoring run.
            return 0.0
    else:
        return SequenceMatcher(None, str1, str2).ratio()
def calculate_time_difference(time1, time2):
    """Absolute gap, in seconds, between two timestamps."""
    delta = time2 - time1
    return -delta if delta < 0 else delta
def calculate_weight(x, y):
    """Hyperbolic decay weight ``1 / (alpha * (x + y) + 1)``.

    With alpha = 0.11 a combined time offset of 1 second weighs ~0.9.
    (The original comment mentioned an exponential form; the implemented
    formula is hyperbolic.)
    """
    alpha = 0.11
    combined_offset = x + y
    return 1 / (alpha * combined_offset + 1)
def is_all_stopwords(sentence):
    """Return True when every element of *sentence* is a stop word.

    Iterating a str yields single characters, so for a plain string this
    checks each character against the global ``stop_words`` set (loaded by
    ``init``). An empty sentence counts as all-stop-words.
    """
    for token in sentence:
        if token not in stop_words:
            return False
    return True
### Note: captions that begin with '-' may be auto-converted into formulas
### when the resulting CSV is opened in Excel; reading it back in code is fine.
def read_srt_to_csv(path_srt, path_output):
    """Parse an .srt subtitle file and write (start, end, text) rows to CSV."""
    with open(path_srt, 'r', encoding='utf-8-sig') as src:
        raw = src.read()
    # One cue = numeric index line, "start --> end" line, then the caption
    # text up to the next numeric index line (or the end of the file).
    cue_pattern = re.compile(r'(\d+)\n([\d:,]+) --> ([\d:,]+)\n(.+?)(?=\n\d+\n|$)', re.DOTALL)
    cues = cue_pattern.findall(raw)
    with open(path_output, 'w', newline='', encoding='utf-8') as dst:
        writer = csv.writer(dst)
        writer.writerow(title)
        for _index, begin, finish, text in cues:
            # Strip styling tags such as {\b1} that some srt tools prepend.
            cleaned = re.sub(r'\{[^}]*\}', '', text)
            writer.writerow([begin, finish, cleaned.strip()])
def read_from_xlsx(path_xlsx='output.xlsx', path_output='deal.csv'):
    """Convert the OCR result spreadsheet into the shared CSV layout.

    Columns 1, 3 and 4 of each sheet row are taken as start time, end time
    and subtitle text (presumably fixed by the OCR export format — confirm
    against the producer). Rows with an empty subtitle cell are skipped.
    """
    data = pd.read_excel(path_xlsx)
    with open(path_output, 'w', newline='', encoding='utf-8') as f:
        csv_writer = csv.writer(f)
        csv_writer.writerow(title)
        for _, data1 in data.iterrows():
            # FIX: positional access on a Series must use .iloc — plain
            # data1[1] is deprecated label-or-position lookup in pandas 2.x
            # and removed in pandas 3.0.
            start, end, subtitle = data1.iloc[1], data1.iloc[3], data1.iloc[4]
            # pd.isna covers float NaN, None and pandas NA markers, unlike
            # the previous isinstance(float) + np.isnan check.
            if pd.isna(subtitle):
                continue
            # Match the srt timestamp style: comma as the decimal separator.
            start = start.replace('.', ',')
            end = end.replace('.', ',')
            csv_writer.writerow([start, end, subtitle.strip()])
### Score OCR subtitles against reference SRT subtitles: for every SRT cue,
### consider OCR cues whose start-or-end offset is within time_threshold
### seconds, score text similarity, and keep the best time-weighted match.
# time_threshold is the maximum acceptable time offset in seconds.
def measure_score(path_srt, path_ocr, time_threshold=5.0, method='cosine'):
    """Compare the reference subtitle CSV against the OCR subtitle CSV.

    Parameters
    ----------
    path_srt : str
        CSV produced by ``read_srt_to_csv`` (reference subtitles).
    path_ocr : str
        CSV produced by ``read_from_xlsx`` (OCR subtitles).
    time_threshold : float
        Maximum |time offset| in seconds for an OCR cue to be a candidate.
    method : str
        Similarity backend, forwarded to ``calculate_similarity``.

    Returns
    -------
    (float, float)
        Mean best-match score over all SRT cues, and over only the cues that
        were actually scored (stop-word-only cues are skipped).
    """
    def _load(path):
        # Read a subtitle CSV, drop the header row, convert both timestamp
        # columns to seconds in place.
        with open(path, 'r', encoding='utf-8') as file:
            rows = [row for row in csv.reader(file)]
        rows.pop(0)
        for row in rows:
            row[0] = change_to_second(row[0])
            row[1] = change_to_second(row[1])
        return rows

    data_srt = _load(path_srt)
    data_ocr = _load(path_ocr)
    total_similarity = 0.0
    total_weight = 0.0
    for sub in data_srt:
        # Skip reference cues consisting entirely of stop words.
        if is_all_stopwords(sub[2]):
            continue
        max_similarity = 0.0
        for sub1 in data_ocr:
            x, y = abs(sub[0] - sub1[0]), abs(sub[1] - sub1[1])
            if min(x, y) <= time_threshold:
                # FIX: honour the caller's `method`; previously 'cosine'
                # was hard-coded here, silently ignoring the parameter.
                score = calculate_similarity(sub[2], sub1[2], method)
                max_similarity = max(max_similarity, score * calculate_weight(x, y))
        total_similarity += max_similarity
        total_weight += 1
    # FIX: guard against empty input or all-stop-word input, which
    # previously raised ZeroDivisionError.
    if not data_srt or total_weight == 0:
        return 0.0, 0.0
    return total_similarity / len(data_srt), total_similarity / total_weight
if __name__ == '__main__':
    init()
    arg_parser = argparse.ArgumentParser(description="benchmark")
    # Command-line options.
    arg_parser.add_argument("--path_srt", required=True, type=str, help="Path of srt file, format is srt")
    arg_parser.add_argument("--path_ocr", required=True, type=str, help="Path of ocr file, format is xlsx")
    arg_parser.add_argument("--method", type=str, default='cosine', help="Select evaluation method")
    arg_parser.add_argument("--time_threshold", type=float, default=5.0, help="Allowable time frame")
    cli = arg_parser.parse_args()

    # Normalize both inputs into the shared CSV layout, then score them.
    srt_csv = 'deal_srt.csv'
    ocr_csv = 'deal_ocr.csv'
    read_srt_to_csv(cli.path_srt, srt_csv)
    read_from_xlsx(cli.path_ocr, ocr_csv)
    scores = measure_score(srt_csv, ocr_csv, cli.time_threshold, cli.method)
    print(f'该评估算法得分: {scores[1]:.5f}')
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment