Complete the inference pipeline
commit 206e4ec367
parent 74e58fd4bb

config.py | 21
@@ -1,14 +1,25 @@
 import os
+import torch


 class ServerConfig(object):
     # Video file configuration
-    video_root_path : str = os.path.join('data', 'video')
-    video_play_path : str = os.path.join(video_root_path, 'play')
-    video_upload_path : str = os.path.join(video_root_path, 'upload')
-    video_chunk_size : int = 1024 * 1024 * 1024
+    videos_root_path : str = os.path.join('data', 'videos')
+    videos_play_path : str = os.path.join(videos_root_path, 'play')
+    videos_upload_path : str = os.path.join(videos_root_path, 'upload')
+    videos_chunk_size : int = 1024 * 1024 * 1024
+    # Data file configuration
+    hdf5_root_path : str = os.path.join('data', 'h5s')

     # TODO


+class ModelConfig(object):
+    pass
+
+
+class DSNetConfig(object):
+    device : str = 'cuda' if torch.cuda.is_available() else 'cpu'
+    random_seed : int = 4906
+    sample_rate : int = 15
+    nms_thresh : float = 0.5
+    lambda_center : float = 1.
+    # TODO
+    pass

dsnet.py (new file; `from dsnet import *` in main.py gives the name)
@@ -0,0 +1,415 @@
import math
import torch
import torch.nn as nn
import cv2 as cv
import numpy as np
import torchvision

from pathlib import Path
from PIL import Image
from torch import Tensor
from torch.nn import Module
from numpy import ndarray, linalg
from torchvision import transforms
from ortools.algorithms.pywrapknapsack_solver import KnapsackSolver
from config import DSNetConfig
from typing import Iterable


class FeatureExtractor(object):

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        self.preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # GoogLeNet with its dropout and fc head stripped: the global average
        # pool then yields a 1024-dimensional descriptor per image
        self.model = torchvision.models.googlenet(pretrained=True)
        self.model = torch.nn.Sequential(*list(self.model.children())[:-2])
        self.model = self.model.to(DSNetConfig.device).eval()

    def run(self, img: ndarray) -> ndarray:
        img: Image.Image = Image.fromarray(img)
        tensor: Tensor = self.preprocess(img)
        batch: Tensor = tensor.unsqueeze(0)
        with torch.no_grad():
            feat = self.model(batch.to(DSNetConfig.device))
            feat = feat.squeeze().cpu().numpy()

        assert feat.shape == (1024,), f'Invalid feature shape {feat.shape}: expected 1024'
        feat /= linalg.norm(feat) + 1e-10
        return feat
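
# Usage sketch (editor's note; `bgr_frame` is a hypothetical OpenCV frame):
#   rgb = cv.cvtColor(bgr_frame, cv.COLOR_BGR2RGB)   # OpenCV decodes frames as BGR
#   feat = FeatureExtractor().run(rgb)               # L2-normalised feature, shape (1024,)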


class VideoPreprocessor(object):

    def __init__(self) -> None:
        super(VideoPreprocessor, self).__init__()
        self.model = FeatureExtractor()
        self.sample_rate = DSNetConfig.sample_rate

    def get_features(self, video_path: str) -> tuple[int, ndarray]:
        video_path = Path(video_path)
        video_capture = cv.VideoCapture(str(video_path))
        assert video_capture.isOpened(), f'Cannot open video: {video_path}'

        features: list[ndarray] = []
        n_frames: int = 0
        while True:
            ret, frame = video_capture.read()
            if not ret:
                break
            if n_frames % self.sample_rate == 0:
                frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
                feat = self.model.run(frame)
                features.append(feat)
            n_frames += 1
        video_capture.release()
        return n_frames, np.array(features)

    @staticmethod
    def calculate_scatters(K: ndarray) -> ndarray:
        n = K.shape[0]
        K1 = np.cumsum([0] + list(np.diag(K)))
        K2 = np.zeros((n + 1, n + 1))
        K2[1:, 1:] = np.cumsum(np.cumsum(K, 0), 1)

        diagK2 = np.diag(K2)

        i = np.arange(n).reshape((-1, 1))
        j = np.arange(n).reshape((1, -1))
        scatters = (
            K1[1:].reshape((1, -1)) - K1[:-1].reshape((-1, 1)) -
            (diagK2[1:].reshape((1, -1)) + diagK2[:-1].reshape((-1, 1)) -
             K2[1:, :-1].T - K2[:-1, 1:]) /
            ((j - i + 1).astype(np.float32) + (j == i - 1).astype(np.float32))
        )
        scatters[j < i] = 0
        return scatters
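
    # What calculate_scatters returns (editor's note): for 0 <= i <= j < n,
    #   J[i, j] = sum_{t=i..j} K[t, t] - (1 / (j - i + 1)) * sum_{s,t=i..j} K[s, t],
    # the within-segment scatter of frames i..j under kernel K, computed for all
    # pairs at once from the cumulative sums K1 (diagonal) and K2 (2-D prefix sum).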

    @staticmethod
    def change_point_detect_nonlin(K: ndarray, ncp: int, lmin: int = 1, lmax: int = 100000, backtrack=True) -> tuple[ndarray, ndarray]:
        m = int(ncp)
        n, n1 = K.shape
        assert n == n1, 'Kernel matrix awaited.'
        assert (m + 1) * lmin <= n <= (m + 1) * lmax
        assert 1 <= lmin <= lmax

        J = VideoPreprocessor.calculate_scatters(K)
        I = 1e101 * np.ones((m + 1, n + 1))
        I[0, lmin:lmax] = J[0, lmin - 1:lmax - 1]
        p = np.zeros((m + 1, n + 1), dtype=int) if backtrack else np.zeros((1, 1), dtype=int)

        for k in range(1, m + 1):  # k: change points placed so far, i.e. the video is split into (k + 1) segments
            for l in range((k + 1) * lmin, n + 1):  # l: end position of the current segment once k change points are placed, i.e. the minimal start of the next one
                tmin = max(k * lmin, l - lmax)  # tmin: the k change points consume at least k * lmin frames, i.e. the earliest start of the current segment
                tmax = l - lmin + 1  # tmax: the latest start that still leaves the current segment at least lmin frames
                c = J[tmin:tmax, l - 1].reshape(-1) + \
                    I[k - 1, tmin:tmax].reshape(-1)
                I[k, l] = np.min(c)
                if backtrack:
                    p[k, l] = np.argmin(c) + tmin

        cps = np.zeros(m, dtype=int)
        if backtrack:
            cur = n
            for k in range(m, 0, -1):
                cps[k - 1] = p[k, cur]
                cur = cps[k - 1]
        scores = I[:, n].copy()
        scores[scores > 1e99] = np.inf
        return cps, scores
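
    # The dynamic program above (editor's note): with J the scatter matrix,
    #   I[k, l] = min over t in [tmin, tmax) of ( I[k - 1, t] + J[t, l - 1] ),
    # the minimal total scatter of splitting the first l frames into (k + 1)
    # segments; p[k, l] stores the argmin so the change points can be backtracked.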

    @staticmethod
    def change_point_detect_auto(K: ndarray) -> tuple[ndarray, ndarray]:
        m, N = len(K) - 1, len(K)
        _, scores = VideoPreprocessor.change_point_detect_nonlin(K, m, backtrack=False)

        penalties = np.zeros(m + 1)
        ncp = np.arange(1, m + 1)
        penalties[1:] = (ncp / (2.0 * N)) * (np.log(float(N) / ncp) + 1)

        costs = scores / float(N) + penalties
        m_best = np.argmin(costs)
        return VideoPreprocessor.change_point_detect_nonlin(K, m_best)
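
    # Model selection above (editor's note): the number of change points m is
    # chosen by minimising  scores[m] / N + (m / (2 N)) * (ln(N / m) + 1),
    # a BIC-style penalty that grows with m and discourages over-segmentation.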

    def kernel_temporal_segment(self, n_frames: int, features: ndarray) -> tuple[ndarray, ndarray, ndarray]:
        seq_len = len(features)
        picks = np.arange(0, seq_len) * self.sample_rate

        kernel = np.matmul(features, features.T)
        change_points, _ = VideoPreprocessor.change_point_detect_auto(kernel)
        change_points *= self.sample_rate
        change_points = np.hstack((0, change_points, n_frames))
        begin_frames = change_points[:-1]
        end_frames = change_points[1:]
        change_points = np.vstack((begin_frames, end_frames - 1)).T

        n_frame_per_seg = end_frames - begin_frames
        return change_points, n_frame_per_seg, picks

    def run(self, video_path: str) -> tuple[int, ndarray, ndarray, ndarray, ndarray]:
        n_frames, features = self.get_features(video_path)
        cps, nfps, picks = self.kernel_temporal_segment(n_frames, features)
        return n_frames, features, cps, nfps, picks
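
# End-to-end preprocessing sketch (editor's note; the path is hypothetical):
#   n_frames, seq, cps, nfps, picks = VideoPreprocessor().run('data/videos/upload/demo.mp4')
#   seq:   (seq_len, 1024) features, one per sampled frame (every 15th, DSNetConfig.sample_rate)
#   cps:   (n_segments, 2) inclusive [first, last] frame range of each shot
#   nfps:  frames per shot;  picks: original frame index of every sampled feature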


class ScaledDotProductAttention(Module):
    def __init__(self, d_k: float):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(0.5)
        self.sqrt_d_k = math.sqrt(d_k)

    def forward(self, Q: Tensor, K: Tensor, V: Tensor) -> tuple[Tensor, Tensor]:
        attn = torch.bmm(Q, K.transpose(2, 1))
        attn = attn / self.sqrt_d_k
        attn = torch.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        y = torch.bmm(attn, V)
        return y, attn
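
# The forward pass above is standard scaled dot-product attention (editor's note):
#   Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V,
# with dropout applied to the attention weights before they multiply V.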


class MultiHeadAttention(nn.Module):
    def __init__(self, num_head: int, num_feature: int) -> None:
        super(MultiHeadAttention, self).__init__()
        self.num_head = num_head
        self.Q = nn.Linear(num_feature, num_feature, bias=False)
        self.K = nn.Linear(num_feature, num_feature, bias=False)
        self.V = nn.Linear(num_feature, num_feature, bias=False)
        self.d_k = num_feature // num_head
        self.attention = ScaledDotProductAttention(self.d_k)
        self.fc = nn.Sequential(
            nn.Linear(num_feature, num_feature, bias=False),
            nn.Dropout(0.5)
        )

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        _, seq_len, num_feature = x.shape  # [1, seq_len, 1024]
        K: Tensor = self.K(x)  # [1, seq_len, 1024]
        Q: Tensor = self.Q(x)  # [1, seq_len, 1024]
        V: Tensor = self.V(x)  # [1, seq_len, 1024]

        K = K.view(1, seq_len, self.num_head, self.d_k).permute(
            2, 0, 1, 3).contiguous().view(self.num_head, seq_len, self.d_k)
        Q = Q.view(1, seq_len, self.num_head, self.d_k).permute(
            2, 0, 1, 3).contiguous().view(self.num_head, seq_len, self.d_k)
        V = V.view(1, seq_len, self.num_head, self.d_k).permute(
            2, 0, 1, 3).contiguous().view(self.num_head, seq_len, self.d_k)

        y, attn = self.attention(Q, K, V)  # [num_head, seq_len, d_k]
        y = y.view(1, self.num_head, seq_len, self.d_k).permute(
            0, 2, 1, 3).contiguous().view(1, seq_len, num_feature)

        y = self.fc(y)
        return y, attn
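
# Shape check (editor's sketch; batch size is fixed to 1 by the view() calls above):
#   mha = MultiHeadAttention(num_head=8, num_feature=1024)
#   y, attn = mha(torch.randn(1, 20, 1024))
#   y.shape    -> (1, 20, 1024)
#   attn.shape -> (8, 20, 20)   # one seq_len x seq_len map per head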


class AttentionExtractor(MultiHeadAttention):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *inputs):
        out, _ = super().forward(*inputs)
        return out


class DSNetAF(Module):

    def __init__(self, num_feature: int, num_hidden: int, num_head: int) -> None:
        super(DSNetAF, self).__init__()
        self.base_model = AttentionExtractor(num_head, num_feature)
        self.layer_norm = nn.LayerNorm(num_feature)
        self.fc1 = nn.Sequential(
            nn.Linear(num_feature, num_hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.LayerNorm(num_hidden),
        )
        self.fc_cls = nn.Linear(num_hidden, 1)
        self.fc_loc = nn.Linear(num_hidden, 2)
        self.fc_ctr = nn.Linear(num_hidden, 1)

    @staticmethod
    def offset2bbox(offsets: np.ndarray) -> np.ndarray:
        offset_left, offset_right = offsets[:, 0], offsets[:, 1]
        seq_len, _ = offsets.shape
        indices = np.arange(seq_len)
        bbox_left = indices - offset_left
        bbox_right = indices + offset_right + 1
        bboxes = np.vstack((bbox_left, bbox_right)).T
        return bboxes
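
    # Worked example (editor's note): offsets = [[1, 2], [0, 1]] yields
    #   index 0 -> [0 - 1, 0 + 2 + 1] = [-1, 3]
    #   index 1 -> [1 - 0, 1 + 1 + 1] = [ 1, 3]
    # i.e. each position proposes the half-open segment [left, right) around itself.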

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        _, seq_len, _ = x.shape
        out = self.base_model(x)
        out = out + x
        out = self.layer_norm(out)
        out = self.fc1(out)
        pred_cls = self.fc_cls(out).sigmoid().view(seq_len)
        pred_loc = self.fc_loc(out).exp().view(seq_len, 2)
        pred_ctr = self.fc_ctr(out).sigmoid().view(seq_len)
        return pred_cls, pred_loc, pred_ctr

    def predict(self, seq: Tensor) -> tuple[ndarray, ndarray]:
        pred_cls, pred_loc, pred_ctr = self(seq)
        pred_cls *= pred_ctr
        pred_cls /= pred_cls.max() + 1e-8
        pred_cls = pred_cls.cpu().numpy()
        pred_loc = pred_loc.cpu().numpy()
        pred_bboxes = DSNetAF.offset2bbox(pred_loc)
        return pred_cls, pred_bboxes


def iou_lr(anchor_bbox: np.ndarray, target_bbox: np.ndarray) -> np.ndarray:
    anchor_left, anchor_right = anchor_bbox[:, 0], anchor_bbox[:, 1]
    target_left, target_right = target_bbox[:, 0], target_bbox[:, 1]

    inter_left = np.maximum(anchor_left, target_left)
    inter_right = np.minimum(anchor_right, target_right)
    union_left = np.minimum(anchor_left, target_left)
    union_right = np.maximum(anchor_right, target_right)

    intersect = inter_right - inter_left
    intersect[intersect < 0] = 0
    union = union_right - union_left
    union[union <= 0] = 1e-6

    iou = intersect / union
    return iou
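
# Worked example (editor's note): anchor [0, 10] vs target [5, 15] gives
# intersection 10 - 5 = 5 and union 15 - 0 = 15, so IoU = 1/3. The "union"
# here is the spanning interval, which equals the true union whenever the
# two segments overlap (disjoint pairs already score an IoU of 0).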


def nms(scores: np.ndarray, bboxes: np.ndarray, thresh: float) -> tuple[np.ndarray, np.ndarray]:
    valid_idx = bboxes[:, 0] < bboxes[:, 1]
    scores = scores[valid_idx]
    bboxes = bboxes[valid_idx]

    arg_desc = scores.argsort()[::-1]

    scores_remain = scores[arg_desc]
    bboxes_remain = bboxes[arg_desc]

    keep_bboxes = []
    keep_scores = []

    while bboxes_remain.size > 0:
        bbox = bboxes_remain[0]
        score = scores_remain[0]
        keep_bboxes.append(bbox)
        keep_scores.append(score)

        iou = iou_lr(bboxes_remain, np.expand_dims(bbox, axis=0))

        keep_indices = (iou < thresh)
        bboxes_remain = bboxes_remain[keep_indices]
        scores_remain = scores_remain[keep_indices]

    keep_bboxes = np.asarray(keep_bboxes, dtype=bboxes.dtype)
    keep_scores = np.asarray(keep_scores, dtype=scores.dtype)
    return keep_scores, keep_bboxes
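
# Greedy 1-D non-maximum suppression (editor's note): repeatedly keep the
# highest-scoring remaining segment and drop every segment whose IoU with it
# reaches thresh (DSNetConfig.nms_thresh = 0.5); the kept segment removes
# itself from the pool as well, since its IoU with itself is 1.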


def knapsack(values: Iterable[int],
             weights: Iterable[int],
             capacity: int
             ) -> list[int]:

    knapsack_solver = KnapsackSolver(
        KnapsackSolver.KNAPSACK_DYNAMIC_PROGRAMMING_SOLVER, 'test'
    )

    values = list(values)
    weights = list(weights)
    capacity = int(capacity)

    knapsack_solver.Init(values, [weights], [capacity])
    knapsack_solver.Solve()
    packed_items = [x for x in range(0, len(weights))
                    if knapsack_solver.BestSolutionContains(x)]

    return packed_items
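
# Toy instance (editor's sketch):
#   knapsack([60, 100, 120], [10, 20, 30], 50)
# returns [1, 2]: total value 220 at exactly weight 50.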


def get_keyshot_summ(pred: np.ndarray,
                     cps: np.ndarray,
                     n_frames: int,
                     nfps: np.ndarray,
                     picks: np.ndarray,
                     proportion: float = 0.15
                     ) -> np.ndarray:
    assert pred.shape == picks.shape
    picks = np.asarray(picks, dtype=np.int32)

    # Get original frame scores from downsampled sequence
    frame_scores = np.zeros(n_frames, dtype=np.float32)
    for i in range(len(picks)):
        pos_lo = picks[i]
        pos_hi = picks[i + 1] if i + 1 < len(picks) else n_frames
        frame_scores[pos_lo:pos_hi] = pred[i]

    # Assign scores to video shots as the average of the frames.
    seg_scores = np.zeros(len(cps), dtype=np.int32)
    for seg_idx, (first, last) in enumerate(cps):
        scores = frame_scores[first:last + 1]
        seg_scores[seg_idx] = int(1000 * scores.mean())

    # Apply knapsack algorithm to find the best shots
    limits = int(n_frames * proportion)
    packed = knapsack(seg_scores, nfps, limits)

    # Get key-shot based summary
    summary = np.zeros(n_frames, dtype=np.bool_)
    for seg_idx in packed:
        first, last = cps[seg_idx]
        summary[first:last + 1] = True
    return summary
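
# Shot selection above solves a 0/1 knapsack (editor's note):
#   maximise   sum_i seg_scores[i] * x_i
#   subject to sum_i nfps[i] * x_i <= proportion * n_frames,  x_i in {0, 1},
# i.e. pick whole shots maximising total importance within a 15% length budget.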


def bbox2summary(seq_len: int,
                 pred_cls: np.ndarray,
                 pred_bboxes: np.ndarray,
                 change_points: np.ndarray,
                 n_frames: int,
                 nfps: np.ndarray,
                 picks: np.ndarray
                 ) -> np.ndarray:
    # each position keeps the best class score among the proposals covering it
    score = np.zeros(seq_len, dtype=np.float32)
    for bbox_idx in range(len(pred_bboxes)):
        lo, hi = pred_bboxes[bbox_idx, 0], pred_bboxes[bbox_idx, 1]
        score[lo:hi] = np.maximum(score[lo:hi], [pred_cls[bbox_idx]])

    pred_summ = get_keyshot_summ(score, change_points, n_frames, nfps, picks)
    return pred_summ


video_preprocessor = VideoPreprocessor()
dsnet_af = DSNetAF(1024, 128, 8).to(DSNetConfig.device)


def extract(video_path: str, save_path: str) -> None:
    # save_path: output path for the summary video
    n_frames, seq, cps, nfps, picks = video_preprocessor.run(video_path)
    with torch.no_grad():
        seq_torch = torch.from_numpy(seq).unsqueeze(0).to(DSNetConfig.device)
        pred_cls, pred_bboxes = dsnet_af.predict(seq_torch)
        pred_bboxes = np.clip(pred_bboxes, 0, len(seq)).round().astype(np.int32)
        pred_cls, pred_bboxes = nms(pred_cls, pred_bboxes, DSNetConfig.nms_thresh)
        pred_summ = bbox2summary(len(seq), pred_cls, pred_bboxes, cps, n_frames, nfps, picks)
    cap = cv.VideoCapture(video_path)
    width = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv.CAP_PROP_FPS)

    # create summary video writer
    fourcc = cv.VideoWriter.fourcc(*'mp4v')
    out = cv.VideoWriter(save_path, fourcc, fps, (width, height))

    frame_idx = 0
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        if pred_summ[frame_idx]:
            out.write(frame)
        frame_idx += 1
    out.release()
    cap.release()

main.py | 93
@@ -1,12 +1,15 @@
 import os
+import h5py
 import uvicorn
+import ffmpeg

 from pathlib import Path
-from fastapi import FastAPI, UploadFile, File, Header
+from fastapi import FastAPI, UploadFile, Form, File, Header
 from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.staticfiles import StaticFiles
 from config import ServerConfig
+from dsnet import *


 app = FastAPI()

@@ -20,21 +23,26 @@ app.add_middleware(
 )


+video_preprocessor = VideoPreprocessor()
+dsnet_af = DSNetAF(1024, 128, 8).to(DSNetConfig.device)
+
+
 @app.post('/upload')
 async def upload_file(file: UploadFile = File(...)):
-    video_path = os.path.join(ServerConfig.video_upload_path, file.filename + '.mp4')
+    video_name = os.path.splitext(file.filename)[0]
+    video_path = os.path.join(ServerConfig.videos_upload_path, video_name + '.mp4')
     with open(video_path, 'wb') as video:
         video.write(await file.read())
-    return {'status': 200, 'id': os.path.splitext(file.filename)[0]}
+    return {'status': 200, 'id': video_name}


 @app.get("/fetch/{id}")
 async def fetch_file(id: str, range: str = Header(None)) -> StreamingResponse:
-    video_path = Path(os.path.join(ServerConfig.video_play_path, id + '.mp4'))
+    video_path = Path(os.path.join(ServerConfig.videos_play_path, id + '.mp4'))
     video_size = video_path.stat().st_size
     start = 0
     end = video_size - 1
-    chunk_size = min(video_size, ServerConfig.video_chunk_size)
+    chunk_size = min(video_size, ServerConfig.videos_chunk_size)
     headers = {
         'Content-Type': 'video/mp4',
         'Content-Disposition': f'attachment; filename="{id}.mp4"',

@@ -43,7 +51,7 @@ async def fetch_file(id: str, range: str = Header(None)) -> StreamingResponse:
         start, end = range.replace('bytes=', '').split('-')
         start = int(start)
         end = int(end) if end else video_size - 1
-        chunk_size = min(end - start + 1, ServerConfig.video_chunk_size)
+        chunk_size = min(end - start + 1, ServerConfig.videos_chunk_size)
         headers['Content-Range'] = f'bytes {start}-{end}/{video_size}'
         headers['Accept-Ranges'] = 'bytes'
         headers['Content-Disposition'] = 'inline'

@@ -58,9 +66,80 @@ async def fetch_file(id: str, range: str = Header(None)) -> StreamingResponse:
     return StreamingResponse(file_reader(), status_code=206, headers=headers, media_type='video/mp4')


+@app.post('/extract')
+async def extract_file(id: str = Form(...)):
+    video_source_path = os.path.join(ServerConfig.videos_upload_path, id + '.mp4')
+    data_target_path = os.path.join(ServerConfig.hdf5_root_path, id + '.h5')
+    with h5py.File(data_target_path, 'a') as data:
+        if id not in data:
+            n_frames, seq, cps, nfps, picks = video_preprocessor.run(video_source_path)
+            data.create_dataset(f'{id}/features', data=seq)
+            data.create_dataset(f'{id}/change_points', data=cps)
+            data.create_dataset(f'{id}/n_frame_per_seg', data=nfps)
+            data.create_dataset(f'{id}/n_frames', data=n_frames)
+            data.create_dataset(f'{id}/picks', data=picks)
+    return {'status': 200, 'id': id}
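
# (editor's note) /extract caches preprocessing: the .h5 file gains one group per
# video id holding features, change_points, n_frame_per_seg, n_frames and picks,
# the exact tuple returned by VideoPreprocessor.run().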
+
+
+@app.post('/analyse')
+async def analyse_file(id: str = Form(...)):
+    data_path = os.path.join(ServerConfig.hdf5_root_path, id + '.h5')
+    pred_summ = None
+    with h5py.File(data_path, 'r') as data:
+        if 'vsumm' not in data[id]:
+            seq = data[id]['features'][...].astype(np.float32)
+            cps = data[id]['change_points'][...].astype(np.int32)
+            nfps = data[id]['n_frame_per_seg'][...].astype(np.int32)
+            n_frames = data[id]['n_frames'][...].astype(np.int32)
+            picks = data[id]['picks'][...].astype(np.int32)
+            with torch.no_grad():
+                seq_torch = torch.from_numpy(seq).unsqueeze(0).to(DSNetConfig.device)
+                pred_cls, pred_bboxes = dsnet_af.predict(seq_torch)
+                pred_bboxes = np.clip(pred_bboxes, 0, len(seq)).round().astype(np.int32)
+                pred_cls, pred_bboxes = nms(pred_cls, pred_bboxes, DSNetConfig.nms_thresh)
+                pred_summ = bbox2summary(len(seq), pred_cls, pred_bboxes, cps, n_frames, nfps, picks)
+    if pred_summ is not None:
+        with h5py.File(data_path, 'a') as data:
+            data.create_dataset(f'{id}/vsumm', data=pred_summ)
+    return {'status': 200, 'id': id}
+
+
+@app.post('/generate')
+async def generate_file(id: str = Form(...)):
+    origin_video_path = os.path.join(ServerConfig.videos_upload_path, id + '.mp4')
+    target_video_path = os.path.join(ServerConfig.videos_play_path, id + '.mpeg4.mp4')
+    final_video_path = os.path.join(ServerConfig.videos_play_path, id + '.mp4')
+    data_path = os.path.join(ServerConfig.hdf5_root_path, id + '.h5')
+    with h5py.File(data_path, 'r') as data:
+        vsumm = data[id]['vsumm'][...].astype(np.bool_)
+    cap = cv.VideoCapture(origin_video_path)
+    width = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
+    height = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))
+    fps = cap.get(cv.CAP_PROP_FPS)
+
+    # create summary video writer; 'mp4v' (MPEG-4 Part 2) matches the
+    # vcodec='mpeg4' input that the ffmpeg re-encode below expects
+    fourcc = cv.VideoWriter.fourcc(*'mp4v')
+    out = cv.VideoWriter(target_video_path, fourcc, fps, (width, height))
+
+    frame_idx = 0
+    while True:
+        ret, frame = cap.read()
+        if not ret:
+            break
+        if vsumm[frame_idx]:
+            out.write(frame)
+        frame_idx += 1
+    out.release()
+    cap.release()
+
+    # re-encode the intermediate file to H.264 so browsers can play it
+    ffmpeg.input(target_video_path, format='mp4', vcodec='mpeg4').output(final_video_path, format='mp4', vcodec='h264').run()
+    return {'status': 200, 'id': id}
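
# End-to-end client sketch (editor's note; assumes the `requests` package and a
# local file demo.mp4):
#   import requests
#   base = 'http://127.0.0.1:8000'
#   vid = requests.post(f'{base}/upload', files={'file': open('demo.mp4', 'rb')}).json()['id']
#   for step in ('extract', 'analyse', 'generate'):
#       requests.post(f'{base}/{step}', data={'id': vid})
#   video = requests.get(f'{base}/fetch/{vid}')   # streams the summarised video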
+
+
 if __name__ == "__main__":
-    uvicorn.run(app=app)
+    # local smoke test of the H264 VideoWriter; the server start is commented out
+    fourcc = cv.VideoWriter.fourcc(*'H264')
+    out = cv.VideoWriter('123.mp4', fourcc, 30, (1024, 768))
+    pass
+    #uvicorn.run(app=app)

requirements.txt
@@ -1,5 +1,6 @@
 mypy==1.9.0
-numpy==1.22.0
+h5py==3.11.0
+numpy==1.26.4
 torch==2.1.0
 torchvision==0.16.0
 opencv-python==4.9.0.80

@@ -7,3 +8,5 @@ opencv-contrib-python==4.9.0.80
 uvicorn==0.29.0
 fastapi==0.110.2
 python-multipart==0.0.9
+ortools==9.9.396
+ffmpeg-python==0.2.0

frontend script (JavaScript)
@@ -59,7 +59,7 @@ function onFileUploadStepped(index, percent) {
  *
  * 1) Change the status column of status-table to 'Extracting'
  * 2) Reset the progress column of status-table to '0%'
- * 3) Call the server endpoint to start the inference
+ * 3) Call the server endpoint to start the extraction
  *
  * @param {number} index index of the file
  * @param {File} file the file object

@@ -69,7 +69,7 @@ function onFileUploadFinished(index, file, data) {
     $(`#stat-label-${index}`).html('Extracting');
     $(`#stat-prog-${index}`).css({'width': '0%'});
     $(`#stat-prog-${index}`).html('0%');
-
+    extractFile(index, file, data.id);
 }

 /**

@@ -97,6 +97,134 @@ function onFileUploadFailed(index, file, err) {
 }

+/**
+ * Called once extraction succeeds; the server has saved the video's feature
+ * data into the h5 file.
+ *
+ * 1) Change the status column of status-table to 'Analysing'
+ * 2) Update the progress column of status-table to 50%
+ * 3) Call the server endpoint to analyse the key frames
+ *
+ * @param {number} index index of the file
+ * @param {File} file the file object
+ * @param {object} data the server response payload
+ */
+function onExtractFileFinished(index, file, data) {
+    $(`#stat-label-${index}`).html('Analysing');
+    $(`#stat-prog-${index}`).css({'width': '50%'});
+    $(`#stat-prog-${index}`).html('50%');
+    analyseFile(index, file, data.id);
+}
+
+/**
+ * Called when extraction fails; updates the relevant UI widgets.
+ *
+ * 1) Change the status column of status-table to 'Extraction failed'.
+ * 2) Reset the progress column of status-table to '0%'
+ * 3) Turn the action button in status-table into a danger button and enable
+ *    it so the failure reason can be viewed.
+ *
+ * @param {number} index index of the file
+ * @param {File} file the file object
+ * @param {error} err the failure reason
+ */
+function onExtractFileFailed(index, file, err) {
+    $(`#stat-label-${index}`).html('<div class="text-danger">Extraction failed</div>');
+    $(`#stat-prog-${index}`).css({'width': '0%'});
+    $(`#stat-prog-${index}`).html('0%');
+    $(`#stat-button-${index}`).removeClass('btn-dark disabled');
+    $(`#stat-button-${index}`).addClass('btn-danger');
+    $(`#stat-button-${index}`).find('span').remove();
+    $(`#stat-button-${index}`).text('View');
+    $(`#stat-button-${index}`).on('click', function() {
+        errorModal('Error', `File ${file.name} failed extraction: ${err}`);
+    });
+}
+
+/**
+ * Called once analysis succeeds; the server has saved both the feature data
+ * and the analysis result into the h5 file.
+ *
+ * 1) Change the status column of status-table to 'Exporting'
+ * 2) Update the progress column of status-table to 75%
+ * 3) Call the server endpoint to generate the summary video
+ *
+ * @param {number} index index of the file
+ * @param {File} file the file object
+ * @param {object} data the server response payload
+ */
+function onAnalyseFileFinished(index, file, data) {
+    $(`#stat-label-${index}`).html('Exporting');
+    $(`#stat-prog-${index}`).css({'width': '75%'});
+    $(`#stat-prog-${index}`).html('75%');
+    generateFile(index, file, data.id);
+}
+
+/**
+ * Called when analysis fails; updates the relevant UI widgets.
+ *
+ * 1) Change the status column of status-table to 'Analysis failed'.
+ * 2) Reset the progress column of status-table to '0%'
+ * 3) Turn the action button in status-table into a danger button and enable
+ *    it so the failure reason can be viewed.
+ *
+ * @param {number} index index of the file
+ * @param {File} file the file object
+ * @param {error} err the failure reason
+ */
+function onAnalyseFileFailed(index, file, err) {
+    $(`#stat-label-${index}`).html('<div class="text-danger">Analysis failed</div>');
+    $(`#stat-prog-${index}`).css({'width': '0%'});
+    $(`#stat-prog-${index}`).html('0%');
+    $(`#stat-button-${index}`).removeClass('btn-dark disabled');
+    $(`#stat-button-${index}`).addClass('btn-danger');
+    $(`#stat-button-${index}`).find('span').remove();
+    $(`#stat-button-${index}`).text('View');
+    $(`#stat-button-${index}`).on('click', function() {
+        errorModal('Error', `File ${file.name} failed analysis: ${err}`);
+    });
+}
+
+/**
+ * Called once generation succeeds; the server has saved the video's summary.
+ *
+ * 1) Update the progress column of status-table to 100%
+ *
+ * @param {number} index index of the file
+ * @param {File} file the file object
+ * @param {object} data the server response payload
+ */
+function onGenerateFileFinished(index, file, data) {
+    $(`#stat-label-${index}`).html('<div class="text-success">Completed</div>');
+    $(`#stat-prog-${index}`).css({'width': '100%'});
+    $(`#stat-prog-${index}`).html('100%');
+    //TODO:
+}
+
+/**
+ * Called when generation fails; updates the relevant UI widgets.
+ *
+ * 1) Change the status column of status-table to 'Generation failed'.
+ * 2) Reset the progress column of status-table to '0%'
+ * 3) Turn the action button in status-table into a danger button and enable
+ *    it so the failure reason can be viewed.
+ *
+ * @param {number} index index of the file
+ * @param {File} file the file object
+ * @param {error} err the failure reason
+ */
+function onGenerateFileFailed(index, file, err) {
+    $(`#stat-label-${index}`).html('<div class="text-danger">Generation failed</div>');
+    $(`#stat-prog-${index}`).css({'width': '0%'});
+    $(`#stat-prog-${index}`).html('0%');
+    $(`#stat-button-${index}`).removeClass('btn-dark disabled');
+    $(`#stat-button-${index}`).addClass('btn-danger');
+    $(`#stat-button-${index}`).find('span').remove();
+    $(`#stat-button-${index}`).text('View');
+    $(`#stat-button-${index}`).on('click', function() {
+        errorModal('Error', `File ${file.name} failed generation: ${err}`);
+    });
+}
+
 /**
  * Uploads the specified file to the server endpoint.
  *
  * This function invokes onFileUploadStepped to report the upload progress.

@@ -148,6 +276,8 @@ function uploadFile(index, file) {
 /**
  * Plays the video with the given server id in the result window.
  *
+ * @param {number} index index of the file
+ * @param {File} file the file object
  * @param {string} id the server-side video id
  */
 function playFile(index, file, id) {

@@ -161,6 +291,90 @@ function playFile(index, file, id) {
 }

+/**
+ * Asks the server to extract the target file's features.
+ *
+ * @param {number} index index of the file
+ * @param {File} file the file object
+ * @param {string} id the server-side video id
+ */
+function extractFile(index, file, id) {
+    var formData = new FormData();
+    formData.append('id', id);
+    $.ajax({
+        url: 'http://127.0.0.1:8000/extract',
+        type: 'POST',
+        data: formData,
+        processData: false, // do not transform the payload
+        contentType: false, // let the browser set the multipart content type
+        success: function(data) {
+            if (data.status != 200) {
+                //TODO: add error handling
+            }
+            onExtractFileFinished(index, file, data);
+        },
+        error: function(xhr, status, error) {
+            onExtractFileFailed(index, file, error);
+        },
+    });
+}
+
+/**
+ * Asks the server to analyse the target file's key frames.
+ *
+ * @param {number} index index of the file
+ * @param {File} file the file object
+ * @param {string} id the server-side video id
+ */
+function analyseFile(index, file, id) {
+    var formData = new FormData();
+    formData.append('id', id);
+    $.ajax({
+        url: 'http://127.0.0.1:8000/analyse',
+        type: 'POST',
+        data: formData,
+        processData: false, // do not transform the payload
+        contentType: false, // let the browser set the multipart content type
+        success: function(data) {
+            if (data.status != 200) {
+                //TODO: add error handling
+            }
+            onAnalyseFileFinished(index, file, data);
+        },
+        error: function(xhr, status, error) {
+            onAnalyseFileFailed(index, file, error);
+        },
+    });
+}
+
+/**
+ * Asks the server to generate the target file's summary video.
+ *
+ * @param {number} index index of the file
+ * @param {File} file the file object
+ * @param {string} id the server-side video id
+ */
+function generateFile(index, file, id) {
+    var formData = new FormData();
+    formData.append('id', id);
+    $.ajax({
+        url: 'http://127.0.0.1:8000/generate',
+        type: 'POST',
+        data: formData,
+        processData: false, // do not transform the payload
+        contentType: false, // let the browser set the multipart content type
+        success: function(data) {
+            if (data.status != 200) {
+                //TODO: add error handling
+            }
+            onGenerateFileFinished(index, file, data);
+        },
+        error: function(xhr, status, error) {
+            onGenerateFileFailed(index, file, error);
+        },
+    });
+}
+
 /**
  * Appends a row to the 'status-table' table showing the uploaded file's
  * information and current status, together with the available actions.
  *
  * @param {number} index index of the file being processed.