icdd-vsumm/main.py

145 lines
5.7 KiB
Python

import os
from pathlib import Path

import ffmpeg
import h5py
import uvicorn
from fastapi import FastAPI, File, Form, Header, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles

from config import ServerConfig
from dsnet import *
# Application object; the contents of the local ``static`` directory are
# served under /home (``html=True`` is passed to StaticFiles).
app = FastAPI()
app.mount('/home', StaticFiles(directory='static', html=True), name='static')

# CORS is configured fully open: any origin, method and header, with
# credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)

# Process-wide objects shared by the request handlers in this module.
video_preprocessor = VideoPreprocessor()
dsnet_af = DSNetAF(1024, 128, 8).to(DSNetConfig.device)
# The identity ``map_location`` leaves the checkpoint tensors where
# ``torch.load`` deserialized them; ``load_state_dict`` then copies them into
# the model, which has already been moved to ``DSNetConfig.device``.
dsnet_af.load_state_dict(
    torch.load(ServerConfig.weights_dsnet_af, map_location=lambda storage, loc: storage)
)
# This module only runs inference with the model, so switch it to evaluation
# mode; otherwise layers such as dropout (if the network has any -- the model
# definition is not visible from this file) stay in training behaviour and
# make predictions non-deterministic.
dsnet_af.eval()
@app.post('/upload')
async def upload_file(file: UploadFile = File(...)):
    """Save an uploaded video as ``<name>.mp4`` in the upload directory.

    ``<name>`` is the client-supplied file name with any directory part and
    the extension removed; it is returned as ``id``.

    Raises:
        HTTPException: 400 when no usable file name remains after stripping
            directory components.
    """
    # ``filename`` is controlled by the client. Keep only its last path
    # component (for both ``/`` and ``\`` separators) so that a name such as
    # ``../../x.mp4`` cannot escape ``ServerConfig.videos_upload_path``.
    raw_name = (file.filename or '').replace('\\', '/')
    video_name = os.path.splitext(os.path.basename(raw_name))[0]
    if video_name in ('', '.', '..'):
        raise HTTPException(status_code=400, detail='Invalid file name')

    video_path = os.path.join(ServerConfig.videos_upload_path, video_name + '.mp4')
    chunk_bytes = 1024 * 1024
    with open(video_path, 'wb') as video:
        # Copy in bounded pieces instead of holding the whole upload in memory.
        while True:
            chunk = await file.read(chunk_bytes)
            if not chunk:
                break
            video.write(chunk)
    return {'status': 200, 'id': video_name}
def _resolve_byte_range(range_header, size):
    """Turn a single-range ``Range`` header value into inclusive offsets.

    Accepts ``bytes=A-B``, ``bytes=A-`` and the suffix form ``bytes=-N``
    (the ``bytes=`` prefix is optional). ``end`` is clamped to ``size - 1``.

    Returns:
        ``(start, end)``, or ``None`` when the value is malformed, lists more
        than one range, or cannot be satisfied for a file of ``size`` bytes.
    """
    spec = range_header.strip()
    if spec.lower().startswith('bytes='):
        spec = spec[len('bytes='):]
    if ',' in spec:
        # Multi-range responses are not supported.
        return None
    first, dash, last = spec.strip().partition('-')
    if not dash:
        return None
    try:
        if first:
            start = int(first)
            end = int(last) if last else size - 1
        else:
            # Suffix form: the final ``N`` bytes of the file.
            suffix_length = int(last)
            if suffix_length <= 0:
                return None
            start = max(size - suffix_length, 0)
            end = size - 1
    except ValueError:
        return None
    end = min(end, size - 1)
    if start < 0 or start > end:
        return None
    return start, end


@app.get("/fetch/{id}")
async def fetch_file(id: str, range: str = Header(None)) -> StreamingResponse:
    """Stream ``<id>.mp4`` from the play directory, honouring ``Range``.

    Without a ``Range`` header the whole file is sent with status 200 as an
    attachment. With one, only the requested bytes are sent with status 206,
    a matching ``Content-Range`` and an inline disposition.

    Raises:
        HTTPException: 404 when the file does not exist; 416 when the
            ``Range`` header is malformed or cannot be satisfied.
    """
    video_path = Path(os.path.join(ServerConfig.videos_play_path, id + '.mp4'))
    if not video_path.is_file():
        raise HTTPException(status_code=404, detail='Video not found')
    video_size = video_path.stat().st_size

    start = 0
    end = video_size - 1
    status = 200
    headers = {
        'Content-Type': 'video/mp4',
        'Content-Disposition': f'attachment; filename="{id}.mp4"',
        'Accept-Ranges': 'bytes',
    }
    if range:
        byte_range = _resolve_byte_range(range, video_size)
        if byte_range is None:
            raise HTTPException(
                status_code=416,
                detail='Requested range not satisfiable',
                headers={'Content-Range': f'bytes */{video_size}'},
            )
        start, end = byte_range
        status = 206
        headers['Content-Range'] = f'bytes {start}-{end}/{video_size}'
        headers['Content-Disposition'] = 'inline'

    chunk_size = ServerConfig.videos_chunk_size

    def file_reader():
        # Stop after ``end``: the body must contain exactly the bytes that
        # ``Content-Range`` announces, not everything up to end of file.
        remaining = end - start + 1
        with open(video_path, 'rb') as video:
            video.seek(start)
            while remaining > 0:
                data = video.read(min(chunk_size, remaining))
                if not data:
                    break
                remaining -= len(data)
                yield data

    return StreamingResponse(file_reader(), status_code=status, headers=headers, media_type='video/mp4')
@app.post('/extract')
async def extract_file(id: str = Form(...)):
    """Run the video preprocessor on ``<id>.mp4`` and cache the result.

    The five values returned by ``video_preprocessor.run`` are stored in
    ``<id>.h5`` under the group ``<id>`` as ``features``, ``change_points``,
    ``n_frame_per_seg``, ``n_frames`` and ``picks``. If that group already
    exists the preprocessing step is skipped.

    Raises:
        HTTPException: 400 for an ``id`` that is not a plain file name;
            404 when the uploaded video does not exist.
    """
    # ``id`` comes from the request and is used both in file paths and as an
    # HDF5 group name, so it must not contain path separators.
    if id in ('', '.', '..') or '/' in id or '\\' in id:
        raise HTTPException(status_code=400, detail='Invalid video id')

    video_source_path = os.path.join(ServerConfig.videos_upload_path, id + '.mp4')
    data_target_path = os.path.join(ServerConfig.h5s_root_path, id + '.h5')
    # Check before opening in append mode, which would otherwise create an
    # empty .h5 file for a video that was never uploaded.
    if not os.path.isfile(video_source_path):
        raise HTTPException(status_code=404, detail='Video not found')

    with h5py.File(data_target_path, 'a') as data:
        if id not in data:
            n_frames, seq, cps, nfps, picks = video_preprocessor.run(video_source_path)
            data.create_dataset(f'{id}/features', data=seq)
            data.create_dataset(f'{id}/change_points', data=cps)
            data.create_dataset(f'{id}/n_frame_per_seg', data=nfps)
            data.create_dataset(f'{id}/n_frames', data=n_frames)
            data.create_dataset(f'{id}/picks', data=picks)
    return {'status': 200, 'id': id}
@app.post('/analyse')
async def analyse_file(id: str = Form(...)):
    """Predict a summary for ``id`` and store it as ``<id>/vsumm``.

    Reads the datasets stored under group ``<id>`` of ``<id>.h5``, runs
    ``dsnet_af`` on the features, applies ``nms`` and ``bbox2summary``, and
    writes the result back to the same file. If ``vsumm`` already exists
    nothing is recomputed.

    Raises:
        HTTPException: 400 for an ``id`` that is not a plain file name;
            404 when no extracted data exists for ``id``.
    """
    # ``id`` comes from the request and is used in a file path and as an
    # HDF5 group name, so it must not contain path separators.
    if id in ('', '.', '..') or '/' in id or '\\' in id:
        raise HTTPException(status_code=400, detail='Invalid video id')

    data_path = os.path.join(ServerConfig.h5s_root_path, id + '.h5')
    if not os.path.isfile(data_path):
        raise HTTPException(status_code=404, detail='Extracted data not found')

    pred_summ = None
    with h5py.File(data_path, 'r') as data:
        if id not in data:
            raise HTTPException(status_code=404, detail='Extracted data not found')
        group = data[id]
        if 'vsumm' not in group:
            seq = group['features'][...].astype(np.float32)
            cps = group['change_points'][...].astype(np.int32)
            nfps = group['n_frame_per_seg'][...].astype(np.int32)
            n_frames = group['n_frames'][...].astype(np.int32)
            picks = group['picks'][...].astype(np.int32)
            with torch.no_grad():
                seq_torch = torch.from_numpy(seq).unsqueeze(0).to(DSNetConfig.device)
                pred_cls, pred_bboxes = dsnet_af.predict(seq_torch)
                pred_bboxes = np.clip(pred_bboxes, 0, len(seq)).round().astype(np.int32)
                pred_cls, pred_bboxes = nms(pred_cls, pred_bboxes, DSNetConfig.nms_thresh)
                pred_summ = bbox2summary(len(seq), pred_cls, pred_bboxes, cps, n_frames, nfps, picks)

    # The file was opened read-only above; reopen it in append mode only when
    # there is a new summary to store.
    if pred_summ is not None:
        with h5py.File(data_path, 'a') as data:
            data.create_dataset(f'{id}/vsumm', data=pred_summ)
    return {'status': 200, 'id': id}
@app.post('/generate')
async def generate_file(id: str = Form(...)):
    """Render the summary video for ``id`` into the play directory.

    Frames of the uploaded video whose entry in ``<id>/vsumm`` is true are
    written to an intermediate ``<id>.mpeg4.mp4`` (fourcc ``MP4V``), which
    ffmpeg then re-encodes to ``<id>.mp4`` with the ``h264`` codec.

    Raises:
        HTTPException: 400 for an ``id`` that is not a plain file name;
            404 when no stored summary exists for ``id``.
    """
    # ``id`` comes from the request and is used in several file paths, so it
    # must not contain path separators.
    if id in ('', '.', '..') or '/' in id or '\\' in id:
        raise HTTPException(status_code=400, detail='Invalid video id')

    origin_video_path = os.path.join(ServerConfig.videos_upload_path, id + '.mp4')
    target_video_path = os.path.join(ServerConfig.videos_play_path, id + '.mpeg4.mp4')
    final_video_path = os.path.join(ServerConfig.videos_play_path, id + '.mp4')
    data_path = os.path.join(ServerConfig.h5s_root_path, id + '.h5')
    if not os.path.isfile(data_path):
        raise HTTPException(status_code=404, detail='Summary not found')

    with h5py.File(data_path, 'r') as data:
        if id not in data or 'vsumm' not in data[id]:
            raise HTTPException(status_code=404, detail='Summary not found')
        vsumm = data[id]['vsumm'][...].astype(np.bool_)

    cap = cv.VideoCapture(origin_video_path)
    out = None
    try:
        width = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv.CAP_PROP_FPS)
        fourcc = cv.VideoWriter.fourcc(*'MP4V')
        out = cv.VideoWriter(target_video_path, fourcc, fps, (width, height))

        # Stop once the labels run out: a video that decodes more frames than
        # ``vsumm`` has entries must not raise IndexError, and frames without
        # a label are simply not part of the summary.
        n_labels = len(vsumm)
        frame_idx = 0
        while frame_idx < n_labels:
            ret, frame = cap.read()
            if not ret:
                break
            if vsumm[frame_idx]:
                out.write(frame)
            frame_idx += 1
    finally:
        # Release the writer and the capture even if reading or writing fails.
        if out is not None:
            out.release()
        cap.release()

    # ``overwrite_output=True`` lets a repeated request for the same id
    # replace the existing output file instead of ffmpeg refusing to
    # overwrite it.
    ffmpeg.input(target_video_path, format='mp4', vcodec='mpeg4').output(
        final_video_path, format='mp4', vcodec='h264'
    ).run(overwrite_output=True)
    return {'status': 200, 'id': id}
if __name__ == '__main__':
    # Running this file as a script serves the app with uvicorn; no host or
    # port is passed, so uvicorn's own defaults apply.
    uvicorn.run(app)