#
# Copyright (c) 2016 Oracle and/or its affiliates. All rights reserved.
#
# This script configures GRE tunnels (in both directions) between a compute instance and a
# cloud CSG instance (Corente gateway) as part of the Corente VPN solution.
#
# After the initial configuration, it also starts a copy of itself in the background to monitor
# the health of the tunnel and re-establish the tunnel if connectivity is lost.
# 
# All references to CSG and gateway in this script are to a "Cloud CSG" instance.
#
# Since opc-init invokes the pre-bootstrap script using bash, this script is a shell script 
# with embedded python code. For the same reason, shebang won't work.
#
# Refer to Corente VPN documentation for details on how to use this script.
#

# Path this script was invoked as; the embedded Python copies it to MONITOR_SCRIPT.
CONFIGURE_SCRIPT=$0
# Location of the installed copy that is run as the background monitor.
MONITOR_SCRIPT=/usr/bin/opc-corente-monitor
# File the background monitor redirects its output to.
MONITOR_LOG=/var/log/opc-compute/opc-corente-monitor.log

# Feed the embedded Python program to the interpreter on stdin, forwarding any
# command-line arguments. The here-document delimiter is unquoted, so bash
# expands the variables above (and processes backslash-escaped dollar signs)
# inside the Python text before Python sees it.
python - $* <<END

import sys
import subprocess
from optparse import OptionParser
import socket, struct
import re
import time

# Values imported from the enclosing bash script: bash substitutes them while
# expanding the here-document, before Python parses this text.
MONITOR_SCRIPT = "$MONITOR_SCRIPT"
CONFIGURE_SCRIPT = "$CONFIGURE_SCRIPT"
MONITOR_LOG = "$MONITOR_LOG"

# Minimum age of the monitor log, in seconds, before main_monitor truncates
# and reopens it
MONITOR_LOG_ROTATE_INTERVAL = 36000

# Program name and version printed in the log banner by main_opc_init
PROG_NAME = "oc-config-corente-tunnel"
PROG_VERSION = "1.5"

# URL from which main_opc_init fetches the tunnel arguments (instance userdata)
METADATA_ENDPOINT = "http://192.0.0.192/latest/attributes/userdata/corente-tunnel-args"

# Timeouts passed to curl (--connect-timeout and --max-time) - in seconds
CURL_CONNECT_TIMEOUT = 3
CURL_MAX_TIME = 5

# Name of the GRE interface created, pinged and deleted by this script
GRE_INTERFACE = "gre1"

# Largest prefix length accepted when validating the on-premises subnets
MAX_CIDR_MASK = 32

# ping defaults used by check_health when the corresponding option is not
# given - in seconds except count
PING_COUNT = 3
PING_TIMEOUT = 2 
PING_INTERVAL = 10  

# Smallest ping option values accepted by validate_args - in seconds except count
MIN_PING_COUNT = 2
MIN_PING_TIMEOUT = 1
MIN_PING_INTERVAL = 3

# Value of this key should match the key value on CSG. This key does not
# serve any purpose (like security) other than identifying the tunnel in case
# of multiple GRE tunnels from this instance.
TUNNEL_KEY = "1234"

def run_shell_command(logmsg, cmdline, abort_on_error=False, background=False):
    """Run a command line through the shell, logging it and its output.

    Arguments:
      logmsg         -- description printed before the command is run
      cmdline        -- command string handed to the shell (shell=True)
      abort_on_error -- when True, raise an Exception if the command exits
                        with a non-zero status or cannot be started; when
                        False, only print a warning
      background     -- when True, return right after starting the command,
                        without waiting for it or reading its output

    Returns the combined stdout/stderr of the command, ' ' for a background
    command, or an error message string when the command could not be started
    and abort_on_error is False.
    """
    try:
        print("### %s" % (logmsg))
        print("Command: %s" % (cmdline))
        proc = subprocess.Popen(cmdline, stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT, shell=True)
        if background:
            return ' '
        output = proc.communicate()[0]
        print("Output:\n%s" % (output))
        if proc.returncode != 0:
            if abort_on_error:
                raise Exception("Error: Command failed with exit code: %d\n" % (proc.returncode))
            print("Warning: Command failed with exit code: %d\n" % (proc.returncode))
        return output
    except OSError as err:
        failure = "Error running command %s" % cmdline
        print("%s: %s\n" % (failure, err))
        if abort_on_error:
            # Callers that pass abort_on_error detect failure through an
            # exception; handing back the error text would be mistaken for
            # the command's output.
            raise Exception("%s: %s\n" % (failure, err))
        return failure

def process_args(input_args, check_for_instance_addr=False):
    """Parse and validate a whitespace-separated argument string.

    Arguments:
      input_args -- all tunnel arguments as one string; it is split on
                    whitespace before being parsed
      check_for_instance_addr -- passed through to validate_args

    Returns the parsed option values.
    """
    # (short flag, long flag, destination attribute, help text)
    option_table = (
        ("-c", "--csg-hostname", "csg_hostname", "Hostname of CSG instance"),
        ("-i", "--csg-instance-address", "csg_instance_address", "CSG instance address"),
        ("-t", "--csg-tunnel-address", "csg_tunnel_address", "GRE tunnel address of CSG instance"),
        ("-l", "--local-tunnel-address", "local_tunnel_address", "GRE tunnel address of this instance"),
        ("-s", "--onprem-subnets", "onprem_subnets", "List of on-premises networks (CIDR) participating in VPN"),
        ("-p", "--ping-count", "ping_count", "Number of pings of CSG done as part of monitoring"),
        ("-n", "--ping-interval", "ping_interval", "Interval between pings of CSG done as part of monitoring"),
        ("-w", "--ping-timeout", "ping_timeout", "Timeout for each ping of CSG done as part of monitoring"),
    )
    parser = OptionParser()
    for short_flag, long_flag, dest_name, help_text in option_table:
        parser.add_option(short_flag, long_flag, dest=dest_name, help=help_text)

    tokens = input_args.split()
    options, leftovers = parser.parse_args(tokens)
    validate_args(options, parser, leftovers, check_for_instance_addr)
    return options

def isvalid_ip(address):
    """Return True when address is an IPv4 address in dotted-quad form.

    socket.inet_aton on its own also accepts abbreviated and non-decimal
    spellings, so the address is first required to consist of exactly four
    all-digit fields; inet_aton then rejects out-of-range values.
    Anything that is not a string yields False.
    """
    try:
        octets = address.split(".")
        if len(octets) != 4:
            return False
        if not all(octet.isdigit() for octet in octets):
            return False
        socket.inet_aton(address)
        return True
    except (AttributeError, TypeError, ValueError, socket.error):
        return False

def isvalid_subnet_cidr(subnet):
    """Return True when subnet has the form address/prefix-length.

    The address part must satisfy isvalid_ip and the prefix length must not
    exceed MAX_CIDR_MASK.
    """
    parsed = re.match(r"^(\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})/(\d{1,2})$", subnet)
    if parsed is None:
        return False
    network, mask_bits = parsed.groups()
    if isvalid_ip(network):
        return int(mask_bits) <= MAX_CIDR_MASK
    return False

def validate_args(options, parser, args, check_for_instance_addr=False):
    """Validate the parsed tunnel options, reporting problems via parser.error.

    Arguments:
      options -- option values produced by the option parser
      parser  -- the option parser; used only to report errors
      args    -- leftover positional arguments; there must be none
      check_for_instance_addr -- when True, a supplied CSG instance address
                                 is rejected

    Ping settings that are not integers are reported through parser.error
    too, rather than surfacing as an uncaught ValueError.
    """
    def require_minimum(label, value, minimum):
        # Options that were not supplied are replaced by defaults in
        # check_health, so only explicit values are checked here.
        if value is None:
            return
        try:
            acceptable = int(value) >= minimum
        except ValueError:
            acceptable = False
        if not acceptable:
            parser.error("Invalid %s %s\n" % (label, value))

    local_tunnel_address = options.local_tunnel_address
    csg_tunnel_address = options.csg_tunnel_address
    onprem_subnets = options.onprem_subnets

    required = (local_tunnel_address, csg_tunnel_address,
                options.csg_hostname, onprem_subnets)
    if any(value is None for value in required):
        parser.error("Required arg is missing\n")
    if check_for_instance_addr and options.csg_instance_address is not None:
        parser.error("CSG instance address is not to be specified")
    if len(args) != 0:
        parser.error("Incorrect number of arguments\n")
    if not isvalid_ip(local_tunnel_address):
        parser.error("Invalid local tunnel address %s\n" % (local_tunnel_address))
    if not isvalid_ip(csg_tunnel_address):
        parser.error("Invalid CSG tunnel address %s\n" % (csg_tunnel_address))

    require_minimum("ping count", options.ping_count, MIN_PING_COUNT)
    require_minimum("ping timeout", options.ping_timeout, MIN_PING_TIMEOUT)
    require_minimum("ping interval", options.ping_interval, MIN_PING_INTERVAL)

    for onprem_subnet in onprem_subnets.split(","):
        if not isvalid_subnet_cidr(onprem_subnet):
            parser.error("Invalid onprem subnet %s\n" % (onprem_subnet))

def configure_tunnel_local(local_instance_addr, options, csg_instance_addr):
    """Set up this instance's end of the GRE tunnel with a series of ip commands.

    Arguments:
      local_instance_addr -- address of this instance, used as the local
                             endpoint of the tunnel
      options             -- parsed options; the local and CSG tunnel
                             addresses and the on-prem subnets are read
      csg_instance_addr   -- CSG instance address, registered as the
                             link-layer address of the CSG neighbor entry

    Commands are run without abort_on_error, so a failing step only logs a
    warning and the remaining steps still run.
    """
    csg_tunnel_ip = options.csg_tunnel_address

    # Fixed steps, in order: create the tunnel, address it, bring it up,
    # register the CSG as neighbor, and route the CSG tunnel address.
    fixed_steps = (
        ("Adding a new GRE tunnel",
         "ip tunnel add %s mode gre local %s key %s" % (GRE_INTERFACE, local_instance_addr, TUNNEL_KEY)),
        ("Setting address of local end of the tunnel",
         "ip address add %s dev %s" % (options.local_tunnel_address, GRE_INTERFACE)),
        ("Bringing the tunnel up",
         "ip link set up dev %s" % (GRE_INTERFACE)),
        ("Configuring CSG as neighbor",
         "ip neighbor add %s lladdr %s dev %s" % (csg_tunnel_ip, csg_instance_addr, GRE_INTERFACE)),
        ("Adding route to CSG",
         "ip route add %s dev %s" % (csg_tunnel_ip, GRE_INTERFACE)),
    )
    for description, command in fixed_steps:
        run_shell_command(description, command)

    # One route per on-prem subnet, via the CSG end of the tunnel.
    for network in options.onprem_subnets.split(","):
        route_command = "ip route add %s via %s" % (network, csg_tunnel_ip)
        run_shell_command("Adding route to on-prem subnet %s" % (network), route_command)

def configure_tunnel_csg(local_instance_addr, options):
    """Ask the CSG, through its HTTP endpoint, to configure its end of the tunnel.

    Arguments:
      local_instance_addr -- address of this instance, sent as instanceip
      options             -- parsed options; the CSG hostname and the local
                             tunnel address (sent as greip) are read
    """
    # The ampersand is backslash-escaped because the URL is passed unquoted
    # on a shell command line.
    endpoint = "http://%s/cgi-bin/gretunnel?instanceip=%s\&greip=%s" % (
        options.csg_hostname, local_instance_addr, options.local_tunnel_address)
    curl_command = "curl -s -S --connect-timeout %d --max-time %d %s" % (
        CURL_CONNECT_TIMEOUT, CURL_MAX_TIME, endpoint)
    run_shell_command("Calling CSG API endpoint to configure CSG side of the tunnel", curl_command)

def configure_tunnel(options, csg_instance_addr):
    """Configure both ends of the GRE tunnel.

    Reads the IPv4 address of eth0, configures the local end of the tunnel
    and then calls the CSG to configure its end.

    Returns 0 on success, or 1 after printing the error when any step raises.
    """
    try:
        print("")
        # The dollar signs in the awk programs are backslash-escaped for the
        # enclosing bash here-document.
        cmdline = "ip addr ls eth0 | grep inet | grep -v inet6 | awk '{print \$2}' | awk -F/ '{print \$1}'"
        listing = run_shell_command("Reading local instance address", cmdline, True)
        # Only the first address reported for eth0 is used.
        local_instance_addr = listing.splitlines()[0]
        if not isvalid_ip(local_instance_addr):
            raise Exception("Error: Instance address %s is invalid\n" % (local_instance_addr))

        configure_tunnel_local(local_instance_addr, options, csg_instance_addr)
        configure_tunnel_csg(local_instance_addr, options)
    except Exception as problem:
        print(problem)
        return 1
    return 0

def teardown_tunnel(csg_tunnel_address, csg_instance_address):
    """Remove the CSG neighbor entry and then delete the local GRE interface.

    Both commands run without abort_on_error, so a failure only logs a warning.
    """
    removal_steps = (
        ("Deleting neighbor entry",
         "ip neighbor del %s lladdr %s dev %s" % (csg_tunnel_address, csg_instance_address, GRE_INTERFACE)),
        ("Deleting local tunnel interface",
         "ip tunnel del %s" % (GRE_INTERFACE)),
    )
    for description, command in removal_steps:
        run_shell_command(description, command)

def check_health(options):
    """Ping the CSG tunnel address until it fails ping_count times in a row.

    The ping count, interval and timeout come from options when supplied and
    from the PING_* defaults otherwise. One ping is sent per interval through
    the GRE interface; a successful ping resets the failure streak. The
    function returns only once the streak reaches the ping count.
    """
    def setting(raw_value, default):
        # Option values arrive as strings; unset options use the default.
        if raw_value is None:
            return default
        return int(raw_value)

    ping_count = setting(options.ping_count, PING_COUNT)
    ping_interval = setting(options.ping_interval, PING_INTERVAL)
    ping_timeout = setting(options.ping_timeout, PING_TIMEOUT)

    consecutive_failures = 0
    while True:
        time.sleep(ping_interval)
        ping_command = "ping -I %s -c 1 -W %d %s" % (GRE_INTERFACE, ping_timeout, options.csg_tunnel_address)
        try:
            # abort_on_error makes a failed ping raise.
            run_shell_command("Pinging CSG tunnel end", ping_command, True)
        except Exception:
            consecutive_failures += 1
            if consecutive_failures == ping_count:
                return
        else:
            consecutive_failures = 0

def main_opc_init():
    """Configure the tunnel at boot and launch the background monitor.

    Fetches the tunnel arguments from instance userdata, resolves the CSG
    hostname, configures both ends of the tunnel, then copies this script to
    MONITOR_SCRIPT and starts the copy in the background.

    Returns 0 when the tunnel was configured and 1 otherwise. The monitor is
    started in both cases; when the CSG address could not be resolved it is
    started without --csg-instance-address, and it looks the address up
    itself once its health check fails.
    """
    print("### %s %s BEGIN\n" % (time.ctime(), PROG_NAME))
    print("%s version=%s\n" % (PROG_NAME, PROG_VERSION))

    # Fetch args from userdata
    cmdline = "curl -s -S --connect-timeout %d --max-time %d %s" % (CURL_CONNECT_TIMEOUT, CURL_MAX_TIME, METADATA_ENDPOINT)
    args = run_shell_command("Fetching arguments from userdata", cmdline, True)

    # process and validate arguments
    options = process_args(args, True)

    configret = 1
    # Stays None when the lookup fails, so that the monitor start-up below
    # never refers to an unassigned name.
    csg_instance_addr = None
    try:
        # lookup CSG instance address
        cmdline = "dig +short %s" % (options.csg_hostname)
        print('')
        output = run_shell_command("Looking up CSG hostname in DNS", cmdline, True)
        output_lines = output.splitlines()
        if len(output_lines) > 0:
            csg_instance_addr = output_lines[0]
            # configure both ends of tunnel
            configret = configure_tunnel(options, csg_instance_addr)
        else:
            print("Error: DNS lookup of %s failed.\n" % (options.csg_hostname))
    except Exception:
        print("Error: DNS lookup of %s failed.\n" % (options.csg_hostname))

    # copy ourselves and start monitor in background
    cmdline = "cp -f %s %s" % (CONFIGURE_SCRIPT, MONITOR_SCRIPT)
    run_shell_command("Copying to monitor script", cmdline, True)
    cmdline = "chmod 755 %s" % (MONITOR_SCRIPT)
    run_shell_command("Changing monitor script permissions", cmdline, True)

    # Collapse all whitespace in the userdata to single spaces: a newline in
    # it would otherwise split the monitor command line in two.
    cmdline = "bash %s %s" % (MONITOR_SCRIPT, ' '.join(args.split()))
    if csg_instance_addr is not None:
        cmdline += " --csg-instance-address=%s" % (csg_instance_addr)
    cmdline += " &"
    run_shell_command("Starting monitor script", cmdline, False, True)

    print("\n### %s %s END\n" % (time.ctime(), PROG_NAME))
    return configret

def main_monitor():
    """Run the background monitor loop; does not return.

    Redirects stdout and stderr to MONITOR_LOG, parses the command-line
    arguments, then repeats forever: wait in check_health until the tunnel is
    considered down, tear down the existing tunnel (if a CSG instance address
    is known), look the CSG hostname up again and reconfigure the tunnel.
    """
    # redirect output to log (opened unbuffered)
    monitorlog = open(MONITOR_LOG, 'w', 0)
    sys.stdout = monitorlog
    sys.stderr = monitorlog
    prev_log_rot_time = time.time()

    print("### %s %s BEGIN\n" % (time.ctime(), MONITOR_SCRIPT))

    # process and validate arguments
    options = process_args(' '.join(sys.argv[1:]))

    csg_instance_addr = options.csg_instance_address

    while True:
        print("### %s %s iteration BEGIN\n" % (time.ctime(), MONITOR_SCRIPT))

        # perform health check; returns only after repeated ping failures
        check_health(options)

        # health check failed, reconfigure tunnel
        print("### %s Tunnel liveness failure detected\n" % (time.ctime()))
        if csg_instance_addr is not None:
            print("### %s Tearing down existing tunnel \n" % (time.ctime()))
            teardown_tunnel(options.csg_tunnel_address, csg_instance_addr)

        # Reset before the lookup: if the lookup raises, the address must be
        # recorded as unknown rather than left unassigned or carried over
        # from an earlier iteration.
        new_csg_instance_addr = None
        try:
            # lookup CSG instance address
            cmdline = "dig +short %s" % (options.csg_hostname)
            print('')
            output = run_shell_command("Looking up CSG hostname in DNS", cmdline, True)
            output_lines = output.splitlines()
            if len(output_lines) > 0:
                new_csg_instance_addr = output_lines[0]
                print("### %s Reestablishing the tunnel with CSG %s\n" % (time.ctime(), new_csg_instance_addr))
                configure_tunnel(options, new_csg_instance_addr)
            else:
                print("### %s DNS lookup of %s failed. Will try again in next iteration\n" % (time.ctime(), options.csg_hostname))
        except Exception:
            print("### %s DNS lookup of %s failed. Will try again in next iteration\n" % (time.ctime(), options.csg_hostname))

        csg_instance_addr = new_csg_instance_addr

        # truncate and reopen log
        # NOTE: this check is reached only after check_health has returned,
        # i.e. after a detected failure, so the log is not rotated while the
        # tunnel stays healthy.
        if time.time() - prev_log_rot_time > MONITOR_LOG_ROTATE_INTERVAL:
            monitorlog.close()
            monitorlog = open(MONITOR_LOG, 'w', 0)
            sys.stdout = monitorlog
            sys.stderr = monitorlog
            prev_log_rot_time = time.time()
            print("### %s LOG ROTATED\n" % (time.ctime()))

        print("### %s %s iteration END\n" % (time.ctime(), MONITOR_SCRIPT))


if __name__ == "__main__":
    # With no command-line arguments the script was called from opc-init;
    # with arguments it is running as the background monitor.
    entry_point = main_opc_init if len(sys.argv) == 1 else main_monitor
    sys.exit(entry_point())
     
END
