Complete the inference pipeline

zhangkeyang 2024-04-21 00:02:47 +08:00
parent 74e58fd4bb
commit 206e4ec367
5 changed files with 737 additions and 15 deletions

config.py

@@ -1,14 +1,25 @@
 import os
+import torch

 class ServerConfig(object):
     # Video file configuration
-    video_root_path : str = os.path.join('data', 'video')
-    video_play_path : str = os.path.join(video_root_path, 'play')
-    video_upload_path : str = os.path.join(video_root_path, 'upload')
-    video_chunk_size : int = 1024 * 1024 * 1024
+    videos_root_path : str = os.path.join('data', 'videos')
+    videos_play_path : str = os.path.join(videos_root_path, 'play')
+    videos_upload_path : str = os.path.join(videos_root_path, 'upload')
+    videos_chunk_size : int = 1024 * 1024 * 1024
+    # Data file configuration
+    hdf5_root_path : str = os.path.join('data', 'h5s')
     # TODO

-class ModelConfig(object):
+class DSNetConfig(object):
+    device : str = 'cuda' if torch.cuda.is_available() else 'cpu'
+    random_seed : int = 4906
+    sample_rate : int = 15
+    nms_thresh : float = 0.5
+    lambda_center : float = 1.
     # TODO
     pass
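For reference, a minimal sketch (not part of this commit; the prepare_runtime helper name is hypothetical) of how the renamed ServerConfig paths and the new DSNetConfig fields might be consumed at server start-up:

import os
import random

import numpy as np
import torch

from config import ServerConfig, DSNetConfig

def prepare_runtime() -> None:
    # Make sure the data directories exist before the upload/extract endpoints write into them.
    for path in (ServerConfig.videos_play_path,
                 ServerConfig.videos_upload_path,
                 ServerConfig.hdf5_root_path):
        os.makedirs(path, exist_ok=True)
    # Seed every RNG from the config so inference runs are repeatable.
    random.seed(DSNetConfig.random_seed)
    np.random.seed(DSNetConfig.random_seed)
    torch.manual_seed(DSNetConfig.random_seed)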

dsnet.py (new file, 415 lines)

@@ -0,0 +1,415 @@
import math
import torch
import torch.nn as nn
import cv2 as cv
import numpy as np
import torchvision
from pathlib import Path
from PIL import Image
from torch import Tensor
from torch.nn import Module
from numpy import ndarray, linalg
from torchvision import transforms
from ortools.algorithms.pywrapknapsack_solver import KnapsackSolver
from config import DSNetConfig
from typing import Iterable
class FeatureExtractor(object):
def __init__(self):
super(FeatureExtractor, self).__init__()
self.preprocess = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
self.model = torchvision.models.googlenet(pretrained=True)
self.model = torch.nn.Sequential(*list(self.model.children())[:-2])
self.model = self.model.to(DSNetConfig.device).eval()
def run(self, img: ndarray) -> ndarray:
img : Image = Image.fromarray(img)
tensor : Tensor = self.preprocess(img)
batch : Tensor = tensor.unsqueeze(0)
with torch.no_grad():
feat : Tensor|ndarray|None = None
feat = self.model(batch.to(DSNetConfig.device))
feat = feat.squeeze().cpu().numpy()
assert feat.shape == (1024,), f'Invalid feature shape {feat.shape}: expected 1024'
feat /= linalg.norm(feat) + 1e-10
return feat
class VideoPreprocessor(object):
def __init__(self) -> None:
super(VideoPreprocessor, self).__init__()
self.model = FeatureExtractor()
self.sample_rate = DSNetConfig.sample_rate
def get_features(self, video_path: str) -> tuple[int, ndarray]:
video_path = Path(video_path)
video_capture = cv.VideoCapture(str(video_path))
assert video_capture is not None, f'Cannot open video: {video_path}'
features: list[ndarray] = []
n_frames: int = 0
while True:
ret, frame = video_capture.read()
if not ret:
break
if n_frames % self.sample_rate == 0:
frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
feat = self.model.run(frame)
features.append(feat)
n_frames += 1
video_capture.release()
features = np.array(features)
return n_frames, features
def calculate_scatters(K: ndarray) -> ndarray:
n = K.shape[0]
K1 = np.cumsum([0] + list(np.diag(K)))
K2 = np.zeros((n + 1, n + 1))
K2[1:, 1:] = np.cumsum(np.cumsum(K, 0), 1)
diagK2 = np.diag(K2)
i = np.arange(n).reshape((-1, 1))
j = np.arange(n).reshape((1, -1))
scatters = (
K1[1:].reshape((1, -1)) - K1[:-1].reshape((-1, 1)) -
(diagK2[1:].reshape((1, -1)) + diagK2[:-1].reshape((-1, 1)) -
K2[1:, :-1].T - K2[:-1, 1:]) /
((j - i + 1).astype(np.float32) + (j == i - 1).astype(np.float32))
)
scatters[j < i] = 0
return scatters
def change_point_detect_nonlin(K: ndarray, ncp: int, lmin: int = 1, lmax: int = 100000, backtrack=True) -> tuple[ndarray, ndarray]:
m = int(ncp)
n, n1 = K.shape
assert n == n1, 'Square kernel matrix expected.'
assert (m + 1) * lmin <= n <= (m + 1) * lmax
assert 1 <= lmin <= lmax
J = VideoPreprocessor.calculate_scatters(K)
I = 1e101 * np.ones((m + 1, n + 1))
I[0, lmin:lmax] = J[0, lmin - 1:lmax - 1]
p = np.zeros((m + 1, n + 1), dtype=int) if backtrack else np.zeros((1, 1), dtype=int)
for k in range(1, m + 1): # k: k change points have been inserted so far, i.e. the video is split into (k + 1) segments
for l in range((k + 1) * lmin, n + 1): # l: once k change points have been placed, the minimum start position of the next segment, i.e. the end position of the current segment
tmin = max(k * lmin, l - lmax) # tmin: the k existing change points consume at least k * lmin frames, i.e. the minimum start position of the current segment
tmax = l - lmin + 1 # tmax: the maximum start position of the current segment, so that it still contains at least lmin frames
c = J[tmin:tmax, l - 1].reshape(-1) + \
I[k - 1, tmin:tmax].reshape(-1)
I[k, l] = np.min(c)
if backtrack:
p[k, l] = np.argmin(c) + tmin
cps = np.zeros(m, dtype=int)
if backtrack:
cur = n
for k in range(m, 0, -1):
cps[k - 1] = p[k, cur]
cur = cps[k - 1]
scores = I[:, n].copy()
scores[scores > 1e99] = np.inf
return cps, scores
def change_point_detect_auto(K: ndarray) -> tuple[ndarray, ndarray]:
m, N = len(K) - 1, len(K)
_, scores = VideoPreprocessor.change_point_detect_nonlin(K, m, backtrack=False)
penalties = np.zeros(m + 1)
ncp = np.arange(1, m + 1)
penalties[1:] = (ncp / (2.0 * N)) * (np.log(float(N) / ncp) + 1)
costs = scores / float(N) + penalties
m_best = np.argmin(costs)
return VideoPreprocessor.change_point_detect_nonlin(K, m_best)
def kernel_temporal_segment(self, n_frames: int, features: ndarray) -> tuple[ndarray, ndarray, ndarray]:
seq_len = len(features)
picks = np.arange(0, seq_len) * self.sample_rate
kernel = np.matmul(features, features.T)
change_points, _ = VideoPreprocessor.change_point_detect_auto(kernel)
change_points *= self.sample_rate
change_points = np.hstack((0, change_points, n_frames))
begin_frames = change_points[:-1]
end_frames = change_points[1:]
change_points = np.vstack((begin_frames, end_frames - 1)).T
n_frame_per_seg = end_frames - begin_frames
return change_points, n_frame_per_seg, picks
def run(self, video_path: str) -> tuple[int, ndarray, ndarray, ndarray, ndarray]:
n_frames, features = self.get_features(video_path)
cps, nfps, picks = self.kernel_temporal_segment(n_frames, features)
return n_frames, features, cps, nfps, picks
class ScaledDotProductAttention(Module):
def __init__(self, d_k: float):
super(ScaledDotProductAttention, self).__init__()
self.dropout = nn.Dropout(0.5)
self.sqrt_d_k = math.sqrt(d_k)
def forward(self, Q: Tensor, K: Tensor, V: Tensor) -> tuple[Tensor, Tensor]:
attn = torch.bmm(Q, K.transpose(2, 1))
attn = attn / self.sqrt_d_k
attn = torch.softmax(attn, dim=-1)
attn = self.dropout(attn)
y = torch.bmm(attn, V)
return y, attn
class MultiHeadAttention(nn.Module):
def __init__(self, num_head: int, num_feature: int) -> None:
super(MultiHeadAttention, self).__init__()
self.num_head = num_head
self.Q = nn.Linear(num_feature, num_feature, bias=False)
self.K = nn.Linear(num_feature, num_feature, bias=False)
self.V = nn.Linear(num_feature, num_feature, bias=False)
self.d_k = num_feature // num_head
self.attention = ScaledDotProductAttention(self.d_k)
self.fc = nn.Sequential(
nn.Linear(num_feature, num_feature, bias=False),
nn.Dropout(0.5)
)
def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
_, seq_len, num_feature = x.shape # [1, seq_len, 1024]
K: Tensor = self.K(x) # [1, seq_len, 1024]
Q: Tensor = self.Q(x) # [1, seq_len, 1024]
V: Tensor = self.V(x) # [1, seq_len, 1024]
K = K.view(1, seq_len, self.num_head, self.d_k).permute(
2, 0, 1, 3).contiguous().view(self.num_head, seq_len, self.d_k)
Q = Q.view(1, seq_len, self.num_head, self.d_k).permute(
2, 0, 1, 3).contiguous().view(self.num_head, seq_len, self.d_k)
V = V.view(1, seq_len, self.num_head, self.d_k).permute(
2, 0, 1, 3).contiguous().view(self.num_head, seq_len, self.d_k)
y, attn = self.attention(Q, K, V) # [num_head, seq_len, d_k]
y = y.view(1, self.num_head, seq_len, self.d_k).permute(
0, 2, 1, 3).contiguous().view(1, seq_len, num_feature)
y = self.fc(y)
return y, attn
class AttentionExtractor(MultiHeadAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def forward(self, *inputs):
out, _ = super().forward(*inputs)
return out
class DSNetAF(Module):
def __init__(self, num_feature: int, num_hidden: int, num_head: int) -> None:
super(DSNetAF, self).__init__()
self.base_model = AttentionExtractor(num_head, num_feature)
self.layer_norm = nn.LayerNorm(num_feature)
self.fc1 = nn.Sequential(
nn.Linear(num_feature, num_hidden),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.LayerNorm(num_hidden),
)
self.fc_cls = nn.Linear(num_hidden, 1)
self.fc_loc = nn.Linear(num_hidden, 2)
self.fc_ctr = nn.Linear(num_hidden, 1)
def offset2bbox(offsets: np.ndarray) -> np.ndarray:
offset_left, offset_right = offsets[:, 0], offsets[:, 1]
seq_len, _ = offsets.shape
indices = np.arange(seq_len)
bbox_left = indices - offset_left
bbox_right = indices + offset_right + 1
bboxes = np.vstack((bbox_left, bbox_right)).T
return bboxes
def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
_, seq_len, _ = x.shape
out = self.base_model(x)
out = out + x
out = self.layer_norm(out)
out = self.fc1(out)
pred_cls = self.fc_cls(out).sigmoid().view(seq_len)
pred_loc = self.fc_loc(out).exp().view(seq_len, 2)
pred_ctr = self.fc_ctr(out).sigmoid().view(seq_len)
return pred_cls, pred_loc, pred_ctr
def predict(self, seq: Tensor) -> tuple[ndarray, ndarray]:
pred_cls, pred_loc, pred_ctr = self(seq)
pred_cls *= pred_ctr
pred_cls /= pred_cls.max() + 1e-8
pred_cls = pred_cls.cpu().numpy()
pred_loc = pred_loc.cpu().numpy()
pred_bboxes = DSNetAF.offset2bbox(pred_loc)
return pred_cls, pred_bboxes
def iou_lr(anchor_bbox: np.ndarray, target_bbox: np.ndarray) -> np.ndarray:
anchor_left, anchor_right = anchor_bbox[:, 0], anchor_bbox[:, 1]
target_left, target_right = target_bbox[:, 0], target_bbox[:, 1]
inter_left = np.maximum(anchor_left, target_left)
inter_right = np.minimum(anchor_right, target_right)
union_left = np.minimum(anchor_left, target_left)
union_right = np.maximum(anchor_right, target_right)
intersect = inter_right - inter_left
intersect[intersect < 0] = 0
union = union_right - union_left
union[union <= 0] = 1e-6
iou = intersect / union
return iou
def nms(scores: np.ndarray, bboxes: np.ndarray, thresh: float) -> tuple[np.ndarray, np.ndarray]:
valid_idx = bboxes[:, 0] < bboxes[:, 1]
scores = scores[valid_idx]
bboxes = bboxes[valid_idx]
arg_desc = scores.argsort()[::-1]
scores_remain = scores[arg_desc]
bboxes_remain = bboxes[arg_desc]
keep_bboxes = []
keep_scores = []
while bboxes_remain.size > 0:
bbox = bboxes_remain[0]
score = scores_remain[0]
keep_bboxes.append(bbox)
keep_scores.append(score)
iou = iou_lr(bboxes_remain, np.expand_dims(bbox, axis=0))
keep_indices = (iou < thresh)
bboxes_remain = bboxes_remain[keep_indices]
scores_remain = scores_remain[keep_indices]
keep_bboxes = np.asarray(keep_bboxes, dtype=bboxes.dtype)
keep_scores = np.asarray(keep_scores, dtype=scores.dtype)
return keep_scores, keep_bboxes
def knapsack(values: Iterable[int],
weights: Iterable[int],
capacity: int
) -> list[int]:
knapsack_solver = KnapsackSolver(
KnapsackSolver.KNAPSACK_DYNAMIC_PROGRAMMING_SOLVER, 'test'
)
values = list(values)
weights = list(weights)
capacity = int(capacity)
knapsack_solver.Init(values, [weights], [capacity])
knapsack_solver.Solve()
packed_items = [x for x in range(0, len(weights))
if knapsack_solver.BestSolutionContains(x)]
return packed_items
def get_keyshot_summ(pred: np.ndarray,
cps: np.ndarray,
n_frames: int,
nfps: np.ndarray,
picks: np.ndarray,
proportion: float = 0.15
) -> np.ndarray:
assert pred.shape == picks.shape
picks = np.asarray(picks, dtype=np.int32)
# Get original frame scores from downsampled sequence
frame_scores = np.zeros(n_frames, dtype=np.float32)
for i in range(len(picks)):
pos_lo = picks[i]
pos_hi = picks[i + 1] if i + 1 < len(picks) else n_frames
frame_scores[pos_lo:pos_hi] = pred[i]
# Assign scores to video shots as the average of the frames.
seg_scores = np.zeros(len(cps), dtype=np.int32)
for seg_idx, (first, last) in enumerate(cps):
scores = frame_scores[first:last + 1]
seg_scores[seg_idx] = int(1000 * scores.mean())
# Apply knapsack algorithm to find the best shots
limits = int(n_frames * proportion)
packed = knapsack(seg_scores, nfps, limits)
# Get key-shot based summary
summary = np.zeros(n_frames, dtype=np.bool_)
for seg_idx in packed:
first, last = cps[seg_idx]
summary[first:last + 1] = True
return summary
def bbox2summary(seq_len: int,
pred_cls: np.ndarray,
pred_bboxes: np.ndarray,
change_points: np.ndarray,
n_frames: int,
nfps: np.ndarray,
picks: np.ndarray
) -> np.ndarray:
score = np.zeros(seq_len, dtype=np.float32)
for bbox_idx in range(len(pred_bboxes)):
lo, hi = pred_bboxes[bbox_idx, 0], pred_bboxes[bbox_idx, 1]
score[lo:hi] = np.maximum(score[lo:hi], [pred_cls[bbox_idx]])
pred_summ = get_keyshot_summ(score, change_points, n_frames, nfps, picks)
return pred_summ
video_preprocessor = VideoPreprocessor()
dsnet_af = DSNetAF(1024, 128, 8).to(DSNetConfig.device)
def extract(video_path: str, save_path: str) -> None:
n_frames, seq, cps, nfps, picks = video_preprocessor.run(video_path)
with torch.no_grad():
seq_torch = torch.from_numpy(seq).unsqueeze(0).to(DSNetConfig.device)
pred_cls, pred_bboxes = dsnet_af.predict(seq_torch)
pred_bboxes = np.clip(pred_bboxes, 0, len(seq)).round().astype(np.int32)
pred_cls, pred_bboxes = nms(pred_cls, pred_bboxes, DSNetConfig.nms_thresh)
pred_summ = bbox2summary(len(seq), pred_cls, pred_bboxes, cps, n_frames, nfps, picks)
cap = cv.VideoCapture(video_path)
width = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv.CAP_PROP_FPS)
# create summary video writer
fourcc = cv.VideoWriter.fourcc(*'mp4v')
out = cv.VideoWriter(save_path, fourcc, fps, (width, height))
frame_idx = 0
while True:
ret, frame = cap.read()
if not ret:
break
if pred_summ[frame_idx]:
out.write(frame)
frame_idx += 1
out.release()
cap.release()
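For orientation, the inference path that dsnet.py provides can be exercised end to end from Python. The sketch below mirrors what the /extract and /analyse endpoints in main.py do; the video path is a placeholder, and note that the commit does not load a trained checkpoint, so DSNetAF runs here with randomly initialised weights:

import numpy as np
import torch

from config import DSNetConfig
from dsnet import VideoPreprocessor, DSNetAF, nms, bbox2summary

video_path = 'data/videos/upload/demo.mp4'   # hypothetical input file

preprocessor = VideoPreprocessor()
model = DSNetAF(1024, 128, 8).to(DSNetConfig.device).eval()

# 1) Sample frames, extract GoogLeNet features and segment the video with KTS.
n_frames, seq, cps, nfps, picks = preprocessor.run(video_path)

# 2) Score the downsampled sequence and turn the predicted offsets into proposals.
with torch.no_grad():
    seq_torch = torch.from_numpy(seq).unsqueeze(0).to(DSNetConfig.device)
    pred_cls, pred_bboxes = model.predict(seq_torch)

# 3) Clip the proposals, apply NMS and build the per-frame boolean summary.
pred_bboxes = np.clip(pred_bboxes, 0, len(seq)).round().astype(np.int32)
pred_cls, pred_bboxes = nms(pred_cls, pred_bboxes, DSNetConfig.nms_thresh)
summary = bbox2summary(len(seq), pred_cls, pred_bboxes, cps, n_frames, nfps, picks)
print('kept frames:', int(summary.sum()), 'of', n_frames)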

main.py (93 changed lines)

@@ -1,12 +1,15 @@
 import os
+import h5py
 import uvicorn
+import ffmpeg
 from pathlib import Path
-from fastapi import FastAPI, UploadFile, File, Header
+from fastapi import FastAPI, UploadFile, Form, File, Header
 from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
 from config import ServerConfig
+from dsnet import *

 app = FastAPI()
@@ -20,21 +23,26 @@ app.add_middleware(
 )

+video_preprocessor = VideoPreprocessor()
+dsnet_af = DSNetAF(1024, 128, 8).to(DSNetConfig.device)

 @app.post('/upload')
 async def upload_file(file: UploadFile = File(...)):
-    video_path = os.path.join(ServerConfig.video_upload_path, file.filename + '.mp4')
+    video_name = os.path.splitext(file.filename)[0]
+    video_path = os.path.join(ServerConfig.videos_upload_path, video_name + '.mp4')
     with open(video_path, 'wb') as video:
         video.write(await file.read())
-    return {'status': 200, 'id': os.path.splitext(file.filename)[0]}
+    return {'status': 200, 'id': video_name}

 @app.get("/fetch/{id}")
 async def fetch_file(id: str, range: str = Header(None)) -> StreamingResponse:
-    video_path = Path(os.path.join(ServerConfig.video_play_path, id + '.mp4'))
+    video_path = Path(os.path.join(ServerConfig.videos_play_path, id + '.mp4'))
     video_size = video_path.stat().st_size
     start = 0
     end = video_size - 1
-    chunk_size = min(video_size, ServerConfig.video_chunk_size)
+    chunk_size = min(video_size, ServerConfig.videos_chunk_size)
     headers = {
         'Content-Type': 'video/mp4',
         'Content-Disposition': f'attachment; filename="{id}.mp4"',
@@ -43,7 +51,7 @@ async def fetch_file(id: str, range: str = Header(None)) -> StreamingResponse:
         start, end = range.replace('bytes=', '').split('-')
         start = int(start)
         end = int(end) if end else video_size - 1
-        chunk_size = min(end - start + 1, ServerConfig.video_chunk_size)
+        chunk_size = min(end - start + 1, ServerConfig.videos_chunk_size)
         headers['Content-Range'] = f'bytes {start}-{end}/{video_size}'
         headers['Accept-Ranges'] = 'bytes'
         headers['Content-Disposition'] = 'inline'
@@ -58,9 +66,80 @@ async def fetch_file(id: str, range: str = Header(None)) -> StreamingResponse:
     return StreamingResponse(file_reader(), status_code=206, headers=headers, media_type='video/mp4')
@app.post('/extract')
async def extract_file(id: str = Form(...)):
video_source_path = os.path.join(ServerConfig.videos_upload_path, id + '.mp4')
data_target_path = os.path.join(ServerConfig.hdf5_root_path, id + '.h5')
with h5py.File(data_target_path, 'a') as data:
if id not in data:
n_frames, seq, cps, nfps, picks = video_preprocessor.run(video_source_path)
data.create_dataset(f'{id}/features', data=seq)
data.create_dataset(f'{id}/change_points', data=cps)
data.create_dataset(f'{id}/n_frame_per_seg', data=nfps)
data.create_dataset(f'{id}/n_frames', data=n_frames)
data.create_dataset(f'{id}/picks', data=picks)
return {'status': 200, 'id': id}
@app.post('/analyse')
async def analyse_file(id: str = Form(...)):
data_path = os.path.join(ServerConfig.hdf5_root_path, id + '.h5')
pred_summ = None
with h5py.File(data_path, 'r') as data:
if 'vsumm' not in data[id]:
seq = data[id]['features'][...].astype(np.float32)
cps = data[id]['change_points'][...].astype(np.int32)
nfps = data[id]['n_frame_per_seg'][...].astype(np.int32)
n_frames = data[id]['n_frames'][...].astype(np.int32)
picks = data[id]['picks'][...].astype(np.int32)
with torch.no_grad():
seq_torch = torch.from_numpy(seq).unsqueeze(0).to(DSNetConfig.device)
pred_cls, pred_bboxes = dsnet_af.predict(seq_torch)
pred_bboxes = np.clip(pred_bboxes, 0, len(seq)).round().astype(np.int32)
pred_cls, pred_bboxes = nms(pred_cls, pred_bboxes, DSNetConfig.nms_thresh)
pred_summ = bbox2summary(len(seq), pred_cls, pred_bboxes, cps, n_frames, nfps, picks)
if pred_summ is not None:
with h5py.File(data_path, 'a') as data:
data.create_dataset(f'{id}/vsumm', data=pred_summ)
return {'status': 200, 'id': id}
@app.post('/generate')
async def generate_file(id: str = Form(...)):
origin_video_path = os.path.join(ServerConfig.videos_upload_path, id + '.mp4')
target_video_path = os.path.join(ServerConfig.videos_play_path, id + '.mpeg4.mp4')
final_video_path = os.path.join(ServerConfig.videos_play_path, id + '.mp4')
data_path = os.path.join(ServerConfig.hdf5_root_path, id + '.h5')
with h5py.File(data_path, 'r') as data:
vsumm = data[id]['vsumm'][...].astype(np.bool_)
cap = cv.VideoCapture(origin_video_path)
width = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv.CAP_PROP_FPS)
# create summary video writer
fourcc = cv.VideoWriter.fourcc(*'H264')
out = cv.VideoWriter(target_video_path, fourcc, fps, (width, height))
frame_idx = 0
while True:
ret, frame = cap.read()
if not ret:
break
if vsumm[frame_idx]:
out.write(frame)
frame_idx += 1
out.release()
cap.release()
ffmpeg.input(target_video_path, format='mp4', vcodec='mpeg4').output(final_video_path, format='mp4', vcodec='h264').run()
return {'status': 200, 'id': id}
 if __name__ == "__main__":
-    uvicorn.run(app=app)
+    fourcc = cv.VideoWriter.fourcc(*'H264')
+    out = cv.VideoWriter('123.mp4', fourcc, 30, (1024, 768))
+    pass
+    #uvicorn.run(app=app)
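Taken together, the new endpoints form a four-step pipeline (upload, extract, analyse, generate) that the frontend drives over AJAX. Below is a minimal sketch of the same flow from a Python client, assuming the app is served with uvicorn at the address hard-coded in the frontend, that demo.mp4 is a local file, and that the requests package (not part of this commit) is installed:

import requests

BASE = 'http://127.0.0.1:8000'

# 1) Upload the video; the server answers with the id used by all later calls.
with open('demo.mp4', 'rb') as f:
    resp = requests.post(f'{BASE}/upload', files={'file': ('demo.mp4', f, 'video/mp4')})
video_id = resp.json()['id']

# 2) Extract features and change points into the .h5 file.
requests.post(f'{BASE}/extract', data={'id': video_id})

# 3) Run DSNet inference and store the per-frame summary mask.
requests.post(f'{BASE}/analyse', data={'id': video_id})

# 4) Render the summary video and transcode it for playback.
requests.post(f'{BASE}/generate', data={'id': video_id})

# The summary video can then be streamed from /fetch/{id}.
print(f'{BASE}/fetch/{video_id}')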

requirements.txt

@@ -1,5 +1,6 @@
 mypy==1.9.0
-numpy==1.22.0
+h5py==3.11.0
+numpy==1.26.4
 torch==2.1.0
 torchvision==0.16.0
 opencv-python==4.9.0.80
@@ -7,3 +8,5 @@ opencv-contrib-python==4.9.0.80
 uvicorn==0.29.0
 fastapi==0.110.2
 python-multipart==0.0.9
+ortools==9.9.396
+ffmpeg-python==0.2.0

(frontend JavaScript file)

@@ -59,7 +59,7 @@ function onFileUploadStepped(index, percent) {
 *
 * 1) Change the status column of the status-table to '正在提取'
 * 2) Reset the progress column of the status-table to '0%'
- * 3) Call the server API to start running inference
+ * 3) Call the server API to start the extraction
 *
 * @param {number} index index of the file
 * @param {File} file the file object
@@ -69,7 +69,7 @@ function onFileUploadFinished(index, file, data) {
     $(`#stat-label-${index}`).html('正在提取');
     $(`#stat-prog-${index}`).css({'width': '0%'});
     $(`#stat-prog-${index}`).html('0%');
+    extractFile(index, file, data.id);
 }
 /**
@@ -97,6 +97,134 @@ function onFileUploadFailed(index, file, err) {
 }
 /**
* Called after a file's features have been extracted successfully. At this point the server has saved the video's feature information into an .h5 file.
*
* 1) Update the progress column of the status-table to 50%
* 2) Call the server API to start extracting the key frames
*
* @param {number} index index of the file
* @param {File} file the file object
* @param {object} data response data from the server
*/
function onExtractFileFinished(index, file, data) {
$(`#stat-label-${index}`).html('正在分析');
$(`#stat-prog-${index}`).css({'width': '50%'});
$(`#stat-prog-${index}`).html('50%');
analyseFile(index, file, data.id);
}
/**
* Called when feature extraction of a file fails; updates part of the UI.
*
* 1) Change the status column of the status-table to '提取失败'.
* 2) Reset the progress column of the status-table to '0%'
* 3) Turn the button in the action column of the status-table into a danger button and enable it so the failure reason can be viewed.
*
* @param {number} index index of the file
* @param {File} file the file object
* @param {error} err reason of the failure
*/
function onExtractFileFailed(index, file, err) {
$(`#stat-label-${index}`).html('<div class="text-danger">提取失败</div>');
$(`#stat-prog-${index}`).css({'width': '0%'});
$(`#stat-prog-${index}`).html('0%');
$(`#stat-button-${index}`).removeClass('btn-dark disabled');
$(`#stat-button-${index}`).addClass('btn-danger');
$(`#stat-button-${index}`).find('span').remove();
$(`#stat-button-${index}`).text('查看');
$(`#stat-button-${index}`).on('click', function() {
errorModal('错误', `文件${file.name}提取失败: ${err}`);
});
}
/**
* Called after a file has been analysed successfully. At this point the server has saved both the video's features and the analysis result into the .h5 file.
*
* 1) Update the progress column of the status-table to 75%
* 2) Call the server API to generate the summary video
*
* @param {number} index index of the file
* @param {File} file the file object
* @param {object} data response data from the server
*/
function onAnalyseFileFinished(index, file, data) {
$(`#stat-label-${index}`).html('正在导出');
$(`#stat-prog-${index}`).css({'width': '75%'});
$(`#stat-prog-${index}`).html('75%');
generateFile(index, file, data.id);
}
/**
* Called when the analysis of a file fails; updates part of the UI.
*
* 1) Change the status column of the status-table to '分析失败'.
* 2) Reset the progress column of the status-table to '0%'
* 3) Turn the button in the action column of the status-table into a danger button and enable it so the failure reason can be viewed.
*
* @param {number} index index of the file
* @param {File} file the file object
* @param {error} err reason of the failure
*/
function onAnalyseFileFailed(index, file, err) {
$(`#stat-label-${index}`).html('<div class="text-danger">分析失败</div>');
$(`#stat-prog-${index}`).css({'width': '0%'});
$(`#stat-prog-${index}`).html('0%');
$(`#stat-button-${index}`).removeClass('btn-dark disabled');
$(`#stat-button-${index}`).addClass('btn-danger');
$(`#stat-button-${index}`).find('span').remove();
$(`#stat-button-${index}`).text('查看');
$(`#stat-button-${index}`).on('click', function() {
errorModal('错误', `文件${file.name}分析失败: ${err}`);
});
}
/**
* Called after the summary video of a file has been generated successfully. At this point the server has saved the video's summary.
*
* 1) Update the progress column of the status-table to 100%
* 2) Call the server API to generate the summary video
*
* @param {number} index index of the file
* @param {File} file the file object
* @param {object} data response data from the server
*/
function onGenerateFileFinished(index, file, data) {
$(`#stat-label-${index}`).html('<div class="text-success">操作完成</div>');
$(`#stat-prog-${index}`).css({'width': '100%'});
$(`#stat-prog-${index}`).html('100%');
//TODO:
}
/**
* Called when generating the summary video of a file fails; updates part of the UI.
*
* 1) Change the status column of the status-table to '生成失败'.
* 2) Reset the progress column of the status-table to '0%'
* 3) Turn the button in the action column of the status-table into a danger button and enable it so the failure reason can be viewed.
*
* @param {number} index index of the file
* @param {File} file the file object
* @param {error} err reason of the failure
*/
function onGenerateFileFailed(index, file, err) {
$(`#stat-label-${index}`).html('<div class="text-danger">生成失败</div>');
$(`#stat-prog-${index}`).css({'width': '0%'});
$(`#stat-prog-${index}`).html('0%');
$(`#stat-button-${index}`).removeClass('btn-dark disabled');
$(`#stat-button-${index}`).addClass('btn-danger');
$(`#stat-button-${index}`).find('span').remove();
$(`#stat-button-${index}`).text('查看');
$(`#stat-button-${index}`).on('click', function() {
errorModal('错误', `文件${file.name}生成失败: ${err}`);
});
}
/**
* API for uploading the given file to the server.
*
* This function calls back onFileUploadStepped to report the upload progress.
@@ -148,6 +276,8 @@ function uploadFile(index, file) {
 /**
 * Plays the video with the given id from the server in the result window.
 *
+ * @param {index} index index of the file
+ * @param {file} file the file object
 * @param {string} id id of the video on the server
 */
 function playFile(index, file, id) {
@@ -161,6 +291,90 @@ function playFile(index, file, id) {
 }
/** /**
* Asks the server to extract the features of the target file.
*
* @param {index} index index of the file
* @param {file} file the file object
* @param {string} id id of the video on the server
*/
function extractFile(index, file, id) {
var formData = new FormData();
formData.append('id', id);
$.ajax({
url: 'http://127.0.0.1:8000/extract',
type: 'POST',
data: formData,
processData: false, // do not process the data being sent
contentType: false, // do not set a content type
success: function(data) {
if (data.status != 200) {
//TODO: add error handling
}
onExtractFileFinished(index, file, data);
},
error: function(xhr, status, error) {
onExtractFileFailed(index, file, error);
},
});
}
/**
* Asks the server to analyse the key frames of the target file.
*
* @param {index} index index of the file
* @param {file} file the file object
* @param {string} id id of the video on the server
*/
function analyseFile(index, file, id) {
var formData = new FormData();
formData.append('id', id);
$.ajax({
url: 'http://127.0.0.1:8000/analyse',
type: 'POST',
data: formData,
processData: false, // do not process the data being sent
contentType: false, // do not set a content type
success: function(data) {
if (data.status != 200) {
//TODO: add error handling
}
onAnalyseFileFinished(index, file, data);
},
error: function(xhr, status, error) {
onAnalyseFileFailed(index, file, error);
},
});
}
/**
* Asks the server to generate the summary video of the target file.
*
* @param {index} index index of the file
* @param {file} file the file object
* @param {string} id id of the video on the server
*/
function generateFile(index, file, id) {
var formData = new FormData();
formData.append('id', id);
$.ajax({
url: 'http://127.0.0.1:8000/generate',
type: 'POST',
data: formData,
processData: false, // do not process the data being sent
contentType: false, // do not set a content type
success: function(data) {
if (data.status != 200) {
//TODO: add error handling
}
onGenerateFileFinished(index, file, data);
},
error: function(xhr, status, error) {
onGenerateFileFailed(index, file, error);
},
});
}
/**
* Adds a row to the 'status-table' that shows the user the information and current status of the uploaded file, and provides the necessary actions.
*
* @param {number} index index of the file currently being processed.