#!/usr/bin/env python3
#stream data via ipfs
#to stream: ffmpeg | ipfs_stream 
#to watch: ipfs_stream stream_id | ffplay
#copyright 2021 Russell Stickney
#
#modification and distribution of this program is allowed.
import base64
import json
import os
import re
import selectors
import subprocess
import sys
import time
import urllib.parse
import urllib.request

import nacl.exceptions
import nacl.signing
# selector shared by every pipe the event loops in this file wait on
sel = selectors.DefaultSelector()
# verbosity: 0 quiet, 1 progress messages on stderr, 2 also api request/response details
debug = 2
def default_stream():
	'''Stream stdin using this node's identity and configured segmenting.'''
	settings = setup()
	stream(
		settings['id'],
		settings['sign_key'],
		sys.stdin.buffer,
		settings['read_method'],
	)
def decode_sign_message(verify_key, message):
	'''Check message's signature and hand back its data string.

	verify_key: object whose verify(data_bytes, signature_bytes) checks the
		signature (a nacl.signing.VerifyKey raises BadSignatureError on a
		mismatch; that exception is not caught here)
	message: dict with a base64 'sig' and a str 'data'
	returns message['data'] when verify() gives a truthy result, else that result
	'''
	signature = base64.b64decode(message['sig'])
	payload = message['data']
	verified = verify_key.verify(payload.encode(), signature)
	debug and dprint('message ok', verified, payload)
	if not verified:
		return verified
	return payload
	
def dprint(*plist, **kw):
	'''print() to stderr so stdout stays free for stream data.

	Keyword arguments (sep, end, flush, file) are passed on to print;
	file defaults to sys.stderr.
	'''
	kw.setdefault('file', sys.stderr)
	print(*plist, **kw)
def get_linked_data(hash, output_flo):
	'''Fetch a finished video and write its segments, in order, to output_flo.

	hash: address of the node stream() publishes as "full <cid>"; its
		"links" list holds {"Cid": {"/": <segment cid>}} entries
	output_flo: binary file like object the decoded segments are written to
	Drives the shared selector until the link node and every linked
	segment have been handled.
	'''
	link_sequencer = sequence_event(sel, output_flo)
	def write_linked_data(link_node):
		debug and dprint('fetching linked data', link_node.data)
		try:
			links = json.loads(link_node.data.decode())['links']
		except (ValueError, KeyError, TypeError) as err:
			#ipfs dag get failed (empty output) or returned something that is not a link node
			dprint('unable to read link node {}: {}'.format(hash, err))
			return
		for index, link in enumerate(links):
			segment_hash = link['Cid']['/']
			debug and dprint('getting segment:', index, segment_hash)
			link_sequencer.add_event('segment', str(index), segment_hash)
	debug and dprint('geting link node {}'.format(hash))
	link_segment = segment(sel, hash, write_linked_data)
	link_segment.start_segment()
	while not link_segment.finished or link_sequencer.seq:
		for key, mask in sel.select(timeout=10):
			callback = key.data
			callback()
		debug and dprint('waiting on: link finished:', link_segment.finished, 'link sequence', len(link_sequencer.seq))
		
	
def format_multipart_formdata(data, boundary = None, name = 'segment'):
	'''Wrap data as a single-file multipart/form-data request body.

	The stdlib has no multipart encoder, so the body is built by hand.
	data: bytes sent as the file content
	boundary: multipart boundary; by default a random one that does not
		occur in data is generated (a boundary inside the data would end the
		part early and corrupt the upload)
	name: form field name and filename (default "segment")
	returns (body bytes, {'Content-Length': ..., 'Content-Type': ...})
	raises ValueError if an explicitly given boundary occurs in data
	'''
	if boundary is None:
		boundary = os.urandom(16).hex()
		while boundary.encode() in data:
			boundary = os.urandom(16).hex()
	elif boundary.encode() in data:
		raise ValueError('multipart boundary {!r} occurs in data'.format(boundary))
	part_headers = {
		'Content-Disposition': 'form-data; name="{}"; filename="{}"'.format(name, name),
		'Content-Type': 'application/octet-stream',
		}
	head = '--{}\r\n'.format(boundary)
	head += ''.join('{}: {}\r\n'.format(key, value) for key, value in part_headers.items())
	head += '\r\n'
	data_part = head.encode() + data + '\r\n--{}--\r\n'.format(boundary).encode()
	multipart_headers = {
		'Content-Length': str(len(data_part)),
		'Content-Type': 'multipart/form-data; boundary={}'.format(boundary),
		}
	return data_part, multipart_headers
	
		
def ipfs_make_api(ipfs_loc, ipfs_port, api_prefix):
	'''Return a function that calls the ipfs http api at http://ipfs_loc:ipfs_port.

	ipfs_api(api_path, arg_map, data=None, callback=None) -> response body bytes
		api_path: command path appended to api_prefix, e.g. 'pubsub/pub'
		arg_map: query string arguments
		data: optional bytes sent as a multipart file upload
		callback: accepted but currently unused
	Requests are always POSTed because the ipfs rpc api rejects GET.
	HTTP and connection errors propagate from urlopen (urllib.error.URLError
	and subclasses).
	'''
	ipfs_loc = '{}:{}'.format(ipfs_loc, ipfs_port)
	def ipfs_api(api_path, arg_map, data=None, callback=None):
		query = urllib.parse.urlencode(arg_map)
		ipfs_url = urllib.parse.urlunparse(('http', ipfs_loc, api_prefix + api_path, '', query, ''))
		debug > 1 and dprint('ipfs url:', ipfs_url)
		#Request iterates its headers argument, so it must be a dict, never None
		request_headers = {}
		if data is not None:
			data, request_headers = format_multipart_formdata(data)
			debug > 1 and dprint('data:', data)
		ipfs_request = urllib.request.Request(url = ipfs_url, headers = request_headers, data = data, method = 'POST')
		debug > 1 and dprint('debug_request:', ipfs_request)
		with urllib.request.urlopen(ipfs_request) as ipfs_response:
			response_data = ipfs_response.read()
		debug > 1 and dprint(response_data)
		return response_data
	return ipfs_api

ipfs_api = ipfs_make_api(ipfs_loc = '127.0.0.1', ipfs_port = '5001', api_prefix = '/api/v0/')
		
def ipfs_publish(channel, id, sign_key, message):
	'''Sign message and publish it on an ipfs pubsub channel.

		channel: pubsub channel to publish on, usually video/<id>
		id: peer id of the sender; receivers look its public key up by this id
		sign_key: nacl signing key, e.g. nacl.signing.SigningKey(key)
		message: space separated message, such as
			segment <index> <hash>
			watching <id>
		returns (publish_ok, ipfs_pub_err): publish_ok is True when the api
		answered with an empty body; otherwise ipfs_pub_err is the response
		body, or the OSError (e.g. urllib.error.URLError) raised by the request.
		'''
	#sig: detached signature of the message bytes, base64 encoded, as a str for json
	signature = sign_key.sign(message.encode()).signature
	data = {
		'id': id,
		'sig': base64.b64encode(signature).decode(),
		'data': message,
		}
	signed_json_message = json.dumps(data).encode() + b'\n'
	try:
		pub_result = ipfs_api('pubsub/pub', {'arg': channel}, data=signed_json_message)
	except OSError as err:
		#daemon unreachable or http error: report it to the caller instead of aborting
		return False, err
	#stderr, not stdout: stdout carries the video when watching
	debug > 1 and dprint('pub_result', pub_result)
	publish_ok = pub_result == b''
	ipfs_pub_err = pub_result or None
	return publish_ok, ipfs_pub_err
	
def load_ipfs_config():
	'''Return the parsed ipfs repo config file as a dict.

	The repo is $IPFS_PATH when that is set (as for the ipfs cli), otherwise
	$HOME/.ipfs.
	'''
	ipfs_path = os.environ.get('IPFS_PATH') or os.path.join(os.environ['HOME'], '.ipfs')
	with open(os.path.join(ipfs_path, 'config'), 'r') as config_file:
		return json.load(config_file)
def _read_config_file(path):
	'''Parse "key: value" lines of the file at path ("#" starts a comment); {} if missing.'''
	config = {}
	if os.path.exists(path):
		with open(path, 'r') as config_file:
			for line in config_file:
				data, _, _ = line.partition('#')
				key, sep, value = data.partition(':')
				if sep:
					config[key.strip()] = value.strip()
	return config

def _load_sign_key(private_key_path):
	'''Load the base64 nacl signing key at private_key_path, creating it if missing.

	A newly created key file is readable and writable by its owner only.
	'''
	os.makedirs(os.path.dirname(private_key_path), mode=0o700, exist_ok=True)
	if os.path.exists(private_key_path):
		debug and dprint('loading private key', private_key_path)
		with open(private_key_path, 'rb') as private_key_file:
			return nacl.signing.SigningKey(base64.b64decode(private_key_file.read()))
	debug and dprint('generating new private key', private_key_path)
	sign_key = nacl.signing.SigningKey.generate()
	key_fd = os.open(private_key_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
	with open(key_fd, 'wb') as private_key_file:
		private_key_file.write(base64.b64encode(sign_key.encode()))
	return sign_key

def _save_public_key(public_key_path, verify_key):
	'''Write verify_key base64 encoded to public_key_path unless it exists; return the base64 bytes.'''
	os.makedirs(os.path.dirname(public_key_path), exist_ok=True)
	b64_public_key = base64.b64encode(verify_key.encode())
	if not os.path.exists(public_key_path):
		with open(public_key_path, 'wb') as public_key_file:
			public_key_file.write(b64_public_key)
	return b64_public_key

def setup():
	'''Load settings and keys for this node, creating missing key files.

	Reads optional "key: value" settings from ./config, loads or creates the
	signing key in ./private/<peer id> and writes the public key to
	./public/<peer id>. Status lines go to stderr because stdout carries the
	video when watching.

	returns a dict with
		id: the PeerID from the ipfs config
		sign_key: nacl.signing.SigningKey
		read_method: segmenting function for stream(); 2**18 byte segments
			unless config has "segment every: <amount> <bytes|seconds>"
	'''
	config = _read_config_file('config')
	ipfs_config = load_ipfs_config()
	id = ipfs_config['Identity']['PeerID']
	sign_key = _load_sign_key(os.path.join('private', id))
	print('using identity:', id, file=sys.stderr)
	b64_public_key = _save_public_key(os.path.join('public', id), sign_key.verify_key)
	print('public key:', b64_public_key.decode(), file=sys.stderr)
	read_method_map = {
		'bytes': make_segment_size,
		'seconds': make_segment_time,
		}
	info = {
		'id': id,
		'sign_key': sign_key,
		'read_method': make_segment_size(2**18),
		}
	if 'segment every' in config:
		read_amount, _, read_method = config['segment every'].partition(' ')
		read_amount = read_amount.strip()
		read_method = read_method.strip()
		#zero would make make_segment_size return empty "full" chunks forever
		if read_method in read_method_map and read_amount.isdecimal() and int(read_amount) > 0:
			print('segmenting every', read_amount, read_method, file=sys.stderr)
			info['read_method'] = read_method_map[read_method](int(read_amount))
		else:
			print('invalid "segment every": format is "segment every: amount type" where type is one of', ' '.join(read_method_map.keys()), file=sys.stderr)
	return info
		
class segment:
	'''One `ipfs dag get` fetch whose output pipe is read through a selector.

	selector: selector the output pipe is registered with by start_segment
	address: ipfs address passed to `ipfs dag get`
	finished_cb: optional callable, called with this segment once the whole
		output has been read into self.data
	'''
	def __init__(self, selector, address, finished_cb = None):
		self.selector = selector
		self.address = address
		self.data = b''
		self.finished = False
		self.finished_cb = finished_cb
		self.read_size = 100
		self.ipfs_proc = None
		self.stream = None
	def start_segment(self):
		'''Launch `ipfs dag get <address>` and register its stdout for reading.'''
		ipfs_cmd = ['ipfs', 'dag', 'get', self.address]
		self.ipfs_proc = subprocess.Popen(ipfs_cmd, stdout=subprocess.PIPE)
		self.stream = self.ipfs_proc.stdout
		#non-blocking so read_data can drain the pipe without stalling the event loop
		os.set_blocking(self.stream.fileno(), False)
		self.selector.register(self.stream, selectors.EVENT_READ, self.read_data)
	def read_data(self):
		'''Selector callback: collect available output; finish at end of file.'''
		while True:
			chunk = self.stream.read(self.read_size)
			if chunk is None:
				#non-blocking read found nothing more for now: wait for the next event
				return
			if not chunk:
				break
			self.data += chunk
		#end of file: unregister before closing, then reap the ipfs process
		self.selector.unregister(self.stream)
		self.stream.close()
		if self.ipfs_proc is not None:
			self.ipfs_proc.wait()
		self.finished = True
		if self.finished_cb is not None:
			self.finished_cb(self)
	def write(self, segment_file):
		'''Decode this segment's {"bytes": "m<base64>"} node into segment_file.

		Output that is not a json object with a "bytes" key writes nothing;
		an unknown multibase prefix or undecodable base64 writes b''.
		'''
		try:
			message = json.loads(self.data.decode())
		except ValueError:
			#e.g. ipfs dag get failed and produced no output
			dprint('error decoding segment {}: not valid json'.format(self.address))
			return
		if not isinstance(message, dict) or 'bytes' not in message:
			return
		based_data = message['bytes']
		bin_data = b''
		if isinstance(based_data, str) and based_data.startswith('m'):
			payload = based_data[1:]
			try:
				#multibase 'm' is unpadded base64 while stream() writes it padded;
				#topping up the padding accepts both forms
				bin_data = base64.b64decode(payload + '=' * (-len(payload) % 4))
			except ValueError:
				dprint('error decoding data. invalid base64')
		else:
			dprint('error decoding data. not a known multibase type')
		segment_file.write(bin_data)
		
class sequence_event:
	'''Fetch numbered segments in parallel and write them out in index order.

	selector: selector the segments' `ipfs dag get` pipes are registered with
	segment_file: binary file like object finished segments are written to
	max_run: number of fetches allowed to run at once (default 5)
	'''
	def __init__(self, selector, segment_file, max_run = 5):
		self.selector = selector
		self.segment_file = segment_file
		#seq_id -> segment, for every segment not yet written out
		self.seq = {}
		#next seq_id to be written
		self.index = 0
		#seq_id -> True for segments whose fetch has been started
		self.running_seq = {}
		self.max_run = max_run
	def add_event(self, method, sequence, message):
		'''Queue segment `sequence` (a decimal string) fetched from address `message`.

		Only method "segment" is understood. Returns True when the segment was
		queued. A repeated index, or one below the next index to write, is
		rejected so it is never fetched twice or written out of order.
		'''
		if method != 'segment' or not sequence.isdecimal():
			return False
		seq_id = int(sequence)
		if seq_id in self.seq or seq_id < self.index:
			return False
		self.seq[seq_id] = segment(self.selector, message, self.cycle)
		if len(self.running_seq) < self.max_run:
			self.start_event(seq_id)
		return True
	def start_event(self, seq_id):
		'''Mark seq_id as running and start its fetch.'''
		self.running_seq[seq_id] = True
		self.seq[seq_id].start_segment()
	def cycle(self, overide = None):
		'''Write finished segments in order and start queued fetches.

		Used as every segment's finished callback, which passes the finished
		segment as `overide`; only overide == True makes the index advance
		once regardless of whether that segment finished.
		'''
		#note: this implementation of cycle is not robustly designed and should be reviewed
		debug and dprint('segment cycle')
		del_list = []
		#output all finished segments in order
		sorted_segment_list = sorted(self.seq)
		debug and dprint('segments active', ','.join(map(str, sorted_segment_list)) or 'None')
		#catch up to the current segment
		if sorted_segment_list and sorted_segment_list[0] > self.index:
			self.index = sorted_segment_list[0]
		#check all active segments in order, write the ones that are done to file and delete them
		#to do: some sort of timeout or skip so we do not get stuck on a missing segment
		#idea: have segment index be file offset, then we could write sparse files.
		for sindex in sorted_segment_list:
			debug and dprint('checking if segment is finished', sindex)
			if overide == True or (self.seq[sindex].finished and sindex == self.index):
				debug and dprint('segment ok', sindex)
				self.index += 1
				overide = False
			if sindex < self.index and self.seq[sindex].finished:
				debug and dprint('writing segment {}'.format(sindex) )
				self.seq[sindex].write(self.segment_file)
				del_list.append(sindex)
			elif (len(self.running_seq) - len(del_list) < self.max_run) and self.seq[sindex].ipfs_proc is None:
				#a slot is free (counting segments written this pass): start this queued fetch
				self.start_event(sindex)
			else:
				debug and dprint('waiting for segment {} to finish'.format(self.index))
		for dindex in del_list:
			del self.seq[dindex]
			del self.running_seq[dindex]
			debug and dprint('removing segment', dindex )
			
class verify_message:
	''' Check signed messages against the public keys stored in 'public'.

		message is a dictionary with the following keys
		id: the id of the sender, must be present in 'public' directory
		data: string with message data
		sig: base64 encoded signature of the data string
		'''
	#peer ids: legacy RSA ids start with Qm, ed25519 ids with 12D3Koo; both base58btc
	valid_id_exp = re.compile(r'(?:Qm|12D3Koo)[1-9A-HJ-NP-Za-km-z]{40,50}')
	def __init__(self):
		#id -> nacl.signing.VerifyKey, loaded lazily from the 'public' directory
		self.verify_key = {}
		self.stored_key = os.listdir('public')
		if debug:
			for key in self.stored_key:
				dprint('known key', key)
	def check(self, message):
		'''Verify message's signature.

		returns (data_verified, result): data_verified is the verified data
		bytes on success and False otherwise; result is a short reason.
		Malformed messages and bad signatures are reported, never raised,
		because any peer on the channel can send them.
		'''
		if not isinstance(message, dict):
			return False, 'message is not a json object'
		message_id = message.get('id', 'unknown')
		if not isinstance(message_id, str) or not self.valid_id_exp.fullmatch(message_id):
			return False, 'message id failed to meet valid id requirments'
		if message_id not in self.verify_key and message_id in self.stored_key:
			debug and dprint('loading key', message_id)
			with open(os.path.join('public', message_id), 'rb') as pub_file:
				self.verify_key[message_id] = nacl.signing.VerifyKey(base64.b64decode(pub_file.read()))
		if message_id not in self.verify_key:
			return False, 'No public key found for id "{}"'.format(message_id)
		sig = message.get('sig')
		data = message.get('data')
		if not isinstance(sig, str) or not isinstance(data, str):
			return False, 'message is missing a "sig" or "data" string'
		try:
			signature_bytes = base64.b64decode(sig)
			data_verified = self.verify_key[message_id].verify(data.encode(), signature_bytes)
		except (ValueError, nacl.exceptions.BadSignatureError) as err:
			#ValueError covers bad base64, wrong signature length and unencodable data
			return False, 'signature check failed: {}'.format(err)
		return data_verified, 'verified'
class subscribe_stream:
	'''Follow the pubsub channel video/<stream_id> and queue its segments.

	stream_id: peer id of the streamer; only its verified "segment" messages are used
	selector: selector the pubsub pipe is registered with
	segment_file: binary file like object the video is written to
	verify: object whose check(message) returns (ok, reason), e.g. verify_message()
	Construction calls setup(), starts `ipfs pubsub sub video/<stream_id>` and
	publishes "watching <stream_id>" on the channel.
	'''
	def __init__(self, stream_id, selector, segment_file, verify):
		self.stream_id = stream_id
		self.selector = selector
		self.segment_file = segment_file
		#bytes received from the pubsub pipe but not yet parsed into messages
		self.buffer = b''
		self.sequence = {
			self.stream_id: sequence_event(self.selector, self.segment_file),
			}
		self.verify = verify
		self.read_size = 100
		self.jd = json.JSONDecoder()
		user_info = setup()
		self.id = user_info['id']
		self.sign_key = user_info['sign_key']
		self.channel = 'video/{}'.format(stream_id)
		ipfs_sub_command = [
			'ipfs',
			'pubsub',
			'sub',
			self.channel,
			]
		self.ipfs_sub_proc = subprocess.Popen(ipfs_sub_command, stdout=subprocess.PIPE)
		self.proc_file = self.ipfs_sub_proc.stdout
		os.set_blocking(self.proc_file.fileno(), False)
		self.selector.register(self.proc_file, selectors.EVENT_READ, self.read_stream)
		self.send_message('watching {}'.format(stream_id))
	def read_stream(self):
		'''Selector callback: drain the pubsub pipe and handle every complete message.'''
		debug and dprint('reading stream data')
		chunks = []
		at_eof = False
		while True:
			chunk = self.proc_file.read(self.read_size)
			if chunk is None:
				#non-blocking pipe has nothing more for now
				break
			if not chunk:
				at_eof = True
				break
			chunks.append(chunk)
		received = b''.join(chunks)
		debug and dprint('read {} bytes'.format(len(received)))
		self.buffer += received
		messages, self.buffer = self._extract_messages(self.buffer, self.jd)
		for message in messages:
			self._handle_message(message)
		if at_eof:
			debug and dprint('stream {} over closing down'.format(self.stream_id))
			#unregister before closing, as the selectors module requires
			self.selector.unregister(self.proc_file)
			self.proc_file.close()
	@staticmethod
	def _extract_messages(buffer, decoder):
		'''Split complete json values off the front of buffer.

		buffer: bytes received so far
		decoder: json.JSONDecoder whose raw_decode parses one value
		returns (messages, rest): decoded values in order and the unconsumed bytes.
		Leading whitespace (e.g. the newline ipfs_publish appends to each
		message) is skipped because raw_decode does not accept it. A line that
		cannot be parsed is dropped up to its newline; data with no newline yet
		is kept as a possibly incomplete message.
		'''
		messages = []
		while True:
			buffer = buffer.lstrip()
			if not buffer:
				return messages, buffer
			#surrogateescape never fails on partial or invalid utf-8 and encodes
			#back to the exact original bytes, so character offsets map to byte offsets
			text = buffer.decode('utf-8', 'surrogateescape')
			try:
				message, end = decoder.raw_decode(text)
			except (json.JSONDecodeError, RecursionError):
				newline = buffer.find(b'\n')
				if newline == -1:
					return messages, buffer
				buffer = buffer[newline + 1:]
				continue
			messages.append(message)
			buffer = buffer[len(text[:end].encode('utf-8', 'surrogateescape')):]
	def _handle_message(self, message):
		'''Verify one decoded message and dispatch its method.'''
		message_ok, message_ok_reason = self.verify.check(message)
		debug and dprint('message_ok', message_ok)
		if not message_ok:
			debug and dprint('signature failure', message_ok_reason)
			return
		sender = message['id']
		data_args = message['data'].split()
		if not data_args:
			return
		if data_args[0] == 'segment' and len(data_args) >= 3:
			self.method_segment(sender = sender, sequence_id = data_args[1], segment_hash = data_args[2])
		else:
			dprint('unknown message:', sender, message['data'])
	def method_segment(self, sender, sequence_id, segment_hash):
		'''Queue a segment announced by sender, if sender is the watched streamer.'''
		if sender == self.stream_id:
			self.sequence[sender].add_event('segment', sequence_id, segment_hash)
		else:
			debug and dprint('recieved segment from invalid sender:', sender, 'expected from:', self.stream_id, 'segment:', sequence_id, segment_hash)
	def send_message(self, message):
		'''Publish a signed message on this channel; returns True on success.'''
		publish_ok, publish_err = ipfs_publish(self.channel, self.id, self.sign_key, message)
		if publish_ok == False:
			dprint('error publishing message:', publish_err, message)
		return publish_ok
		
					
			
def subscribe(stream_id):
	'''Watch stream_id's channel and write the video to stdout.

	Runs until nothing is registered with the shared selector any more: the
	pubsub pipe has reached end of file and no segment fetch is pending.
	stdout is then flushed.
	'''
	segment_file = sys.stdout.buffer
	verify = verify_message()
	subscription = subscribe_stream(stream_id, sel, segment_file, verify)
	#polling the process would stop while messages still sit unread in the pipe
	#and while segment fetches are in flight; registered pipes track both
	while sel.get_map():
		for key, mask in sel.select(timeout=10):
			callback = key.data
			callback()
	segment_file.flush()
	return subscription.ipfs_sub_proc.poll() and None
def make_segment_size(size):
	'''Return a reader that takes fixed chunks of `size` bytes from a file.

	The reader returns (chunk, more); more is True only when a full chunk
	of `size` bytes was read.
	'''
	def segment_size(flo):
		chunk = flo.read(size)
		more = len(chunk) == size
		return chunk, more
	return segment_size
def make_segment_time(seconds, read_size = None):
	'''Return a reader that collects about `seconds` worth of input per segment.

	The reader reads read_size byte chunks (default 1024) until `seconds`
	have elapsed on a monotonic clock, or until a short read signals the end
	of the input. It returns (buffer, more); more is False after a short read.
	At least one chunk is read per call, so seconds <= 0 cannot produce an
	endless run of empty segments.
	'''
	if read_size is None:
		read_size = 1024
	def segment_time(flo):
		#monotonic: wall clock adjustments must not stretch or cut segments
		ts_start = time.monotonic()
		chunks = []
		while True:
			data = flo.read(read_size)
			chunks.append(data)
			if len(data) < read_size:
				return b''.join(chunks), False
			if time.monotonic() - ts_start >= seconds:
				return b''.join(chunks), True
	return segment_time
		
def _ipfs_dag_put(node):
	'''Store a json serialisable node with `ipfs dag put`.

	returns (cid, stdout, stderr); cid is None when stdout does not start
	with a token free of spaces and newlines followed by a newline.
	'''
	ipfs_proc = subprocess.run(['ipfs', 'dag', 'put'], input=json.dumps(node).encode(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
	output_match = re.match(r'([^\n ]+)\n', ipfs_proc.stdout.decode(errors='replace'))
	cid = output_match[1] if output_match else None
	return cid, ipfs_proc.stdout, ipfs_proc.stderr

def stream(id, nacl_sign_key, flo, flo_read_method = None):
	'''Stream a file like object to the ipfs pubsub channel video/<id>.

	   Every chunk from flo_read_method is stored with `ipfs dag put` as
	   {"bytes": "m<base64>"} and announced as "segment <index> <cid>".
	   When the input ends (or on Ctrl-C), a node linking all segments is
	   stored and announced as "full <cid>".

	   id: peer id; the channel is video/<id>
	   nacl_sign_key: nacl.signing.SigningKey used to sign every message
	   flo: a binary file like object to stream
	   flo_read_method(optional): a function returning the next bytes of flo
	   	and a bool that stays True while more data may follow
	   	flo_read_method(flo) -> stream_bytes, stream_more
	   	defaults to 2**18 byte segments
	   '''
	if flo_read_method is None:
		flo_read_method = make_segment_size(2 ** 18)
	publish_channel = 'video/{}'.format(id)
	link_list = []
	print('streaming to video/{}'.format(id))
	run_flag = True
	index = 0
	try:
		while run_flag:
			data, run_flag = flo_read_method(flo)
			if not data:
				#nothing to send, e.g. the input ended exactly on a segment boundary
				continue
			#kept padded for watchers that pass it straight to base64.b64decode
			node = {'bytes': 'm' + base64.b64encode(data).decode()}
			debug > 1 and dprint('publishing segment:', len(data), 'bytes')
			segment_id, ipfs_out, ipfs_error = _ipfs_dag_put(node)
			if segment_id is None:
				print('unable to publish message:', ipfs_out, ipfs_error)
				continue
			link_list.append({
				'Cid': {
					'/': segment_id,
					},
				'Name': '',
				'Size': len(data),
				})
			pub_message = 'segment {} {}'.format(index, segment_id)
			index += 1
			publish_ok, publish_err = ipfs_publish(publish_channel, id, nacl_sign_key, pub_message)
			if publish_ok:
				dprint('publish:', pub_message)
			else:
				dprint('publish fail:', publish_err, pub_message)
	except KeyboardInterrupt:
		#Ctrl-C ends the live stream; the full video is still recorded below
		pass
	#quick note on the linked video file
	#ideally this would be compatible with ipfs get
	#unfortunately the get format is some sort of protobuf based thing
	#and the ipfs group is moving away from it to the ipld format
	#doubly unfortunately there is no current file specification for the new ipld format
	#so this is all speculation... but it still beats trying to integrate protobufs
	final_video = {
		'links': link_list,
		'data': '',
		}
	full_hash, ipfs_out, ipfs_error = _ipfs_dag_put(final_video)
	if full_hash is None:
		print('unable to store full video:', ipfs_out, ipfs_error)
		return
	print('full video is at {}'.format(full_hash))
	full_message = 'full {}'.format(full_hash)
	publish_ok, publish_err = ipfs_publish(publish_channel, id, nacl_sign_key, full_message)
	if not publish_ok:
		dprint('publish fail:', publish_err, full_message)
	
	
			
if __name__ == '__main__':
	#usage:
	#  ipfs_stream               stream stdin, e.g. ffmpeg ... | ipfs_stream
	#  ipfs_stream <id>          watch <id>'s stream on stdout, e.g. | ffplay -
	#  ipfs_stream get <cid>     write a finished video to stdout
	if len(sys.argv) > 1:
		if sys.argv[1] == 'get' and len(sys.argv) > 2:
			get_linked_data(sys.argv[2], sys.stdout.buffer)
		else:
			watch_id = sys.argv[1]
			if os.path.exists(os.path.join('public', watch_id)):
				subscribe(watch_id)
			else:
				#stdout is reserved for video data: report on stderr and fail
				print('no public key found for {}: unable to stream'.format(watch_id), file=sys.stderr)
				sys.exit(1)
	else:
		default_stream()