赵心治 / accessibility_movie_2 / Commits

Commit 5652328b
authored Oct 09, 2023 by wux
fix:add new ocr_metric algorithm
parent c01b4948
Showing 2 changed files with 87 additions and 19 deletions (+87 -19)

.gitignore       +16  -0
ocr_metric.py    +71  -19
.gitignore (view file @ 5652328b)

@@ -14,3 +14,18 @@ res/ffmpeg-4.3.1/bin/qiji_local.mp4
 venv/
 venv37/
 shenming_test
+cap.png
+requirements3.8.txt
+venv3.8-new/
+webrtcvad-2.0.10-cp38-abi3-win_amd64.whl
+xlsx-resource/
+deal_ocr.csv
+deal_srt.csv
+new.srt
+shenhai1.xlsx
+shenhai2.xlsx
+test,py
+"\346\267\261\346\265\267\347\237\255\347\211\2072.xlsx"
+"\346\267\261\346\265\267\347\237\255\347\211\207origin.xlsx"
\ No newline at end of file
ocr_metric.py (view file @ 5652328b)
 import re
 import csv
+import jieba
 import argparse
 import pandas as pd
 import numpy as np
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
 from difflib import SequenceMatcher
-title = ['起始时间(转换后)', '终止时间(转换后)', '字幕']
+# title = ['起始时间(转换后)', '终止时间(转换后)', '字幕']
+title = ['起始时间', '终止时间', '字幕']
 def init():
     # 获取中文停用词列表
@@ -22,13 +25,25 @@ def change_to_second(time_str):
               time_obj.second + time_obj.microsecond / 1000000
     return seconds
+# 将中文句子划分,并且防止划分全部为停用词
+def words_segment(str):
+    tmp = ','.join(jieba.cut(str))
+    # 将分割的句子差分成单词,也不进行划分
+    if is_all_stopwords(tmp) or len(list(jieba.cut(str))) == len(str):
+        return str
+    return tmp
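The segmentation above relies on jieba, the Chinese word-segmentation library added to the imports in this commit. A standalone sketch of what the helper works with; the exact split depends on jieba's dictionary, so the outputs shown are only indicative:

import jieba

sentence = "今天天气不错"
print(list(jieba.cut(sentence)))      # e.g. ['今天', '天气', '不错']
print(','.join(jieba.cut(sentence)))  # e.g. '今天,天气,不错' -- the comma-joined form built by words_segment
# The guard len(list(jieba.cut(str))) == len(str) holds only when every token is a
# single character, i.e. the sentence could not be split into multi-character words.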
 # 计算字幕的相似度
 def calculate_similarity(str1, str2, method='cosine'):
     if method == 'cosine':
-        tfidf_vectorizer = TfidfVectorizer()
-        tfidf_matrix = tfidf_vectorizer.fit_transform([str1, str2])
+        str1, str2 = words_segment(str1), words_segment(str2)
+        tfidf_vectorizer = TfidfVectorizer(min_df=1)
+        tfidf_matrix = tfidf_vectorizer.fit_transform([str1, str2])
+        # shape=[2, N]
+        # print(np.array(tfidf_matrix.toarray()).shape, type(tfidf_matrix), tfidf_matrix.toarray())
         similarity_matrix = cosine_similarity(tfidf_matrix)
         return similarity_matrix[0][1]
+    elif method == 'distance':
+        return -String_edit_distance(str1, str2)
     else:
         return SequenceMatcher(None, str1, str2).ratio()
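In the 'cosine' branch, fit_transform on the two segmented strings produces a 2-by-N TF-IDF matrix, and cosine_similarity of that matrix is a symmetric 2-by-2 matrix, so the [0][1] entry is the pairwise score. A minimal, self-contained sketch with made-up sample strings; note that the default tokenizer only keeps tokens of two or more word characters, so multi-character words from jieba survive while single characters are dropped:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

a = "深海,纪录,短片"   # stand-ins for words_segment() output
b = "深海,短片,字幕"

tfidf_matrix = TfidfVectorizer(min_df=1).fit_transform([a, b])  # sparse matrix, shape (2, N)
similarity_matrix = cosine_similarity(tfidf_matrix)             # shape (2, 2), ones on the diagonal
print(similarity_matrix[0][1])                                  # the score returned above, in [0, 1]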
@@ -37,15 +52,37 @@ def calculate_time_difference(time1, time2):
     return abs(time2 - time1)
 def calculate_weight(x, y):
-    # weight = e^(-alpha * time_diff)
-    # 相差1s的系数为0.9
-    alpha = 0.11
-    return 1 / (alpha * (x + y) + 1)
+    # # weight = e^(-alpha * time_diff)
+    # # 相差1s的系数为0.9
+    # alpha = 0.11
+    # return 1 / (alpha * (x + y) + 1)
+    return 1.0  # 目前不考虑时间系数
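The disabled weighting followed weight = 1 / (alpha * (x + y) + 1) with alpha = 0.11, which matches the original comment that a one-second offset keeps a coefficient of about 0.9 ('相差1s的系数为0.9'); the updated function now ignores timing and always returns 1.0 ('目前不考虑时间系数': the time coefficient is not considered for now). A quick numerical check of the old formula:

alpha = 0.11
print(1 / (alpha * (1 + 0) + 1))   # ~0.9009: a combined 1 s start/end offset keeps ~90% of the score
print(1 / (alpha * (5 + 5) + 1))   # ~0.4762: larger offsets are discounted much more heavily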
 # 检查句子中的每个单词是否都是停用词
 def is_all_stopwords(sentence):
+    sentence = sentence.replace(' ', '')
     return all(word in stop_words for word in sentence)
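Note that iterating over a Python string yields characters, so the all(...) check is per character rather than per word; the newly added replace(' ', '') keeps stray spaces from defeating it. A standalone illustration, where the stop_words set is a made-up stand-in for the list loaded in init(), which is outside this diff:

stop_words = {'的', '了', '嗯', ','}   # hypothetical stand-in for the real stopword list

def is_all_stopwords(sentence):
    sentence = sentence.replace(' ', '')
    return all(word in stop_words for word in sentence)   # 'word' is actually one character

print(is_all_stopwords('嗯 的'))   # True: only stop characters remain after stripping spaces
print(is_all_stopwords('深海'))    # False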
+# 编辑距离算法 有问题!!!!!!
+def String_edit_distance(str1, str2):
+    n, m = len(str1), len(str2)
+    dp = [[0 for _ in range(m + 1)] for _ in range(n + 1)]
+    for i in range(n + 1):
+        dp[i][0] = i
+    for i in range(m + 1):
+        dp[0][i] = i
+    dp[0][0] = 0
+    for i in range(1, n + 1):
+        for j in range(1, m + 1):
+            if str1[i - 1] == str2[j - 1]:
+                dp[i][j] = dp[i - 1][j - 1]
+            else:
+                dp[i][j] = min(dp[i - 1][j - 1], min(dp[i][j - 1], dp[i - 1][j])) + 1
+    # print(dp[n][m], n, m)
+    return 1.0 * dp[n][m] / max(n, m)
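String_edit_distance is a standard dynamic-programming Levenshtein distance, normalized by the length of the longer string, so it returns a value in [0, 1]; the 'distance' branch of calculate_similarity negates it so that larger values still mean more similar. Despite the '有问题' ("has problems") note in the comment, the recurrence itself is the usual one. A quick worked check, with the function above in scope:

print(String_edit_distance("kitten", "sitting"))   # 3 edits / max(6, 7) = 0.42857...
print(String_edit_distance("深海", "深海短片"))       # 2 insertions / max(2, 4) = 0.5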
 ### 如果其中有-符号,可能在用excel打开时自动添加=变成公式,读取的时候没问题
 def read_srt_to_csv(path_srt, path_output):
     with open(path_srt, 'r', encoding='utf-8-sig') as f:
@@ -70,20 +107,19 @@ def read_from_xlsx(path_xlsx='output.xlsx', path_output='deal.csv'):
         csv_writer.writerow(title)
         for _, data1 in data.iterrows():
-            start, end, subtitle = data1[1], data1[3], data1[4]
+            # print(data1[1])
+            start, end, subtitle = data1[0], data1[1], data1[2]
             if isinstance(subtitle, float) and np.isnan(subtitle):
                 continue
             # 与srt文件格式同步
             start = start.replace('.', ',')
             end = end.replace('.', ',')
-            # print(start, end, subtitle,)
-            # print(type(start), type(end), type(subtitle))
             csv_writer.writerow([start, end, subtitle.strip()])
 ### 对于srt中的字幕计算相似性度。从ocr中找到时间戳满足<=time_t的字幕,
 ### 然后计算字幕间的相似度,取一个最大的。字幕从start和end都匹配一遍
 # time_threshold设置阈值,用于判断时间差是否可接受
-def measure_score(path_srt, path_ocr, time_threshold=5.0, method='cosine'):
+def measure_score(path_srt, path_ocr, time_threshold=5.0, time_threshold_re=False, method='cosine'):
     data_srt, data_ocr = [], []
     with open(path_srt, 'r', encoding='utf-8') as file:
         csv_reader = csv.reader(file)
@@ -105,20 +141,29 @@ def measure_score(path_srt, path_ocr, time_threshold=5.0, method='cosine'):
     total_weight = 0.0
     for sub in data_srt:
-        max_similarity = 0.0
+        max_similarity = 0.0 if method != 'distance' else -1000.0
         # 去除srt中的停用词
         if is_all_stopwords(sub[2]):
             continue
+        # subb = ""
         for sub1 in data_ocr:
             x, y = abs(sub[0] - sub1[0]), abs(sub[1] - sub1[1])
-            if min(x, y) <= time_threshold:
-                # print(sub[2], sub1[2])
-                score = calculate_similarity(sub[2], sub1[2], 'cosine')
+            if time_threshold_re:
+                time_threshold_tmp = time_threshold
+            else:
+                time_threshold_tmp = (sub[1] - sub[0]) * 0.3  # 10s允许3s的误差
+            if min(x, y) <= time_threshold_tmp:
+                score = calculate_similarity(sub[2], sub1[2], method)
+                # if max_similarity <= score * calculate_weight(x, y):
+                #     subb = sub1[2]
                 max_similarity = max(max_similarity, score * calculate_weight(x, y))
+        # if max_similarity <= 0.5:
+        #     print(max_similarity, sub[2], subb, sub[0])
         total_similarity += max_similarity
         total_weight += 1
+    if method == 'distance':
+        total_similarity = total_weight + total_similarity
+    # print(total_similarity, total_similarity / len(data_srt), total_similarity / total_weight)
     return total_similarity / len(data_srt), total_similarity / total_weight
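Two behaviors of the reworked loop are worth spelling out. With time_threshold_re disabled, the matching window is no longer the fixed time_threshold but 30% of each subtitle's own duration, per the '10s允许3s的误差' comment (a 10 s cue tolerates a 3 s offset). With method='distance', each per-cue best score is a negated normalized edit distance in [-1, 0], so the later total_similarity = total_weight + total_similarity turns the sum into the sum of (1 - normalized distance) over the scored cues. A small sketch of the adaptive window, with made-up values:

sub = (12.0, 22.0, "示例字幕")                 # (start_s, end_s, text) of one SRT cue -- hypothetical
time_threshold_tmp = (sub[1] - sub[0]) * 0.3   # 10 s duration -> 3.0 s tolerated start/end offset
print(time_threshold_tmp)                      # 3.0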
 if __name__ == '__main__':
@@ -128,13 +173,19 @@ if __name__ == '__main__':
     # 添加命令行参数
     parser.add_argument("--path_srt", required=True, type=str, help="Path of srt file, format is srt")
     parser.add_argument("--path_ocr", required=True, type=str, help="Path of ocr file, format is xlsx")
-    parser.add_argument("--method", type=str, default='cosine', help="Select evaluation method")
     parser.add_argument("--time_threshold", type=float, default=5.0, help="Allowable time frame")
+    parser.add_argument("--method", type=str, default='distance', choices=['cosine', 'distance', 'sequence'], help="Select evaluation method")
+    parser.add_argument("--time_threshold_re", type=bool, default=True, help="Specify whether \
+time threshold is required")
     args = parser.parse_args()
     output_file_srt = 'deal_srt.csv'
     output_file_ocr = 'deal_ocr.csv'
     read_srt_to_csv(args.path_srt, output_file_srt)
     read_from_xlsx(args.path_ocr, output_file_ocr)
-    score = measure_score(output_file_srt, output_file_ocr, args.time_threshold, args.method)
+    score = measure_score(output_file_srt, output_file_ocr, args.time_threshold, \
+                          args.time_threshold_re, args.method)
     print(f'该评估算法得分: {score[1]:.5f}')
\ No newline at end of file
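A possible invocation of the updated script, with placeholder file names; the intermediate deal_srt.csv and deal_ocr.csv are written next to the script before the score is printed:

    python ocr_metric.py --path_srt reference.srt --path_ocr ocr_output.xlsx --method distance --time_threshold 5.0

One caveat with the new flag: argparse's type=bool converts any non-empty string to True, so passing --time_threshold_re False still yields True; with default=True, the fixed-threshold branch is effectively always taken unless an empty string is passed.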