#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-2.0-or-later

from topotato.v1 import *
from topotato.network import DropHost

import logging
import asyncio
from topotato.livescapy import LiveScapy

from scapy.supersocket import SuperSocket
from scapy.config import conf

conf.use_2byte_asn = False

from scapy.contrib.bgp import (
    BGP,
    BGPHeader,
    BGPOpen,
    BGPKeepAlive,
)
from scapy.layers.l2 import Ether
from scapy.layers.inet import (
    IP,
    TCP,
)

"""
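Trigger BGP connection collision resolution: the fake peer holds off FRR's
outgoing TCP connection, brings up its own connection to FRR, and then
completes FRR's connection, so that OPENs are exchanged on both at once.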
"""


# Test topology: routers "dut" and "peer" attached to the single network "lan".
# NOTE: the docstring below is (presumably) parsed by topology_fixture as the
# topology definition itself -- the function has no other body -- so do not
# reword or reformat it.
@topology_fixture()
def topology(topo):
    """
    [ dut ]
      |
    { lan }
      |
    [ peer ]
    """


class FRRConfigured(RouterFRR):
    # Per-daemon configuration templates for the FRR router under test.  The
    # "#%" lines and "{{ ... }}" placeholders are template syntax layered on
    # top of "boilerplate.conf"; the text is runtime data, keep it verbatim.

    # zebra: nothing beyond the boilerplate is needed for this test.
    zebra = """
    #% extends "boilerplate.conf"
    #% block main
    #% endblock
    """

    # bgpd: AS 65001 with one eBGP neighbor, addressed by the first IPv4
    # address of the peer router's first interface.  "remote-as external"
    # accepts any AS other than our own (the fake peer uses 65002), and
    # "no bgp ebgp-requires-policy" disables FRR's requirement of an explicit
    # policy on eBGP sessions.
    bgpd = """
    #% extends "boilerplate.conf"
    #% block main
    router bgp 65001
     no bgp ebgp-requires-policy
     neighbor {{ routers.peer.ifaces[0].ip4[0].ip }} remote-as external
    !
    #% endblock
    """


class Setup(TopotatoNetwork, topo=topology):
    # Network setup: binds the routers drawn in `topology` to implementations.

    # The FRR instance under test, running the FRRConfigured templates.
    dut: FRRConfigured
    # The fake BGP speaker.  It is a DropHost rather than an FRR router;
    # presumably it does not answer traffic on its own, so that all of its
    # TCP/BGP packets can be hand-crafted with scapy -- TODO confirm against
    # topotato.network.DropHost.
    peer: DropHost


# TCP sequence numbers live in a 32-bit space (RFC 9293); all arithmetic on
# them must wrap, otherwise scapy cannot build the packet.
_TCP_SEQ_MOD = 1 << 32


def _seq_add(seq: int, delta: int) -> int:
    """Return TCP sequence number *seq* advanced by *delta*, modulo 2**32."""
    return (seq + delta) % _TCP_SEQ_MOD


@Setup.trait
async def foobar(net):
    """
    Drive a BGP connection collision against the DUT from the fake peer.

    All peer-side packets are hand-crafted with scapy on the "lan" network:

    1. wait for FRR's outgoing SYN ("conn1") and leave it unanswered;
    2. open our own connection to the DUT's port 179 ("conn2"), complete the
       handshake and wait for FRR's OPEN on it;
    3. now answer conn1 with a SYN-ACK and wait for FRR's ACK and OPEN;
    4. send an OPEN and then a KEEPALIVE on both connections, so both are
       being brought up towards the same peer at once;
    5. keep logging every TCP packet observed afterwards.
    """
    _logger = logging.getLogger(__name__)

    _scapy: SuperSocket = net.scapys["lan"]
    lscapy: LiveScapy = _scapy.live

    dut = net.network.routers["dut"]
    peer = net.network.routers["peer"]

    common_hdr = Ether(src=peer.iface_to("lan").macaddr, dst=dut.iface_to("lan").macaddr)
    common_hdr /= IP(src=str(peer.iface_to("lan").ip4[0].ip), dst=str(dut.iface_to("lan").ip4[0].ip))

    bgp_port = 179
    # our (the fake peer's) source port on the connection we initiate
    conn2_port = 31337
    # our initial sequence numbers; arbitrary, but easy to spot in dumps
    conn1_isn = 0x1337CAFE
    conn2_isn = 0xDEADF00D

    async with lscapy.observe() as obs:
        # conn1 (FRR -> peer:179): wait for FRR's SYN, but do not answer it
        # yet, so FRR's outgoing connection stays pending

        conn1_syn = await obs([(TCP, lambda tcp: tcp.flags.S and tcp.dport == bgp_port)])
        _logger.info("RX SYN: %r", conn1_syn.pkt)
        conn1_tcp = conn1_syn.pkt.getlayer(TCP)
        # FRR's ephemeral source port on conn1
        conn1_port = conn1_tcp.sport

        def from_frr_conn1(tcp):
            """Match segments FRR sends on conn1 (its ephemeral port -> 179)."""
            return tcp.sport == conn1_port and tcp.dport == bgp_port

        def from_frr_conn2(tcp):
            """Match segments FRR sends on conn2 (179 -> our conn2 port)."""
            return tcp.sport == bgp_port and tcp.dport == conn2_port

        # conn2 (peer -> FRR:179): establish our own session while conn1's
        # SYN is being "blocked"

        conn2_syn = common_hdr.copy()
        conn2_syn /= TCP(sport=conn2_port, dport=bgp_port, seq=conn2_isn, ack=0, flags="S")
        _logger.info("TX SYN: %r", conn2_syn)
        obs.send(conn2_syn)

        conn2_synack = await obs([(TCP, lambda tcp: tcp.flags.S and tcp.flags.A and from_frr_conn2(tcp))])
        _logger.info("RX SYNACK: %r", conn2_synack.pkt)
        conn2_tcp = conn2_synack.pkt.getlayer(TCP)

        conn2_ack = common_hdr.copy()
        conn2_ack /= TCP(
            sport=conn2_port,
            dport=bgp_port,
            seq=_seq_add(conn2_isn, 1),
            ack=_seq_add(conn2_tcp.seq, 1),
            flags="A",
        )
        _logger.info("TX ACK: %r", conn2_ack)
        obs.send(conn2_ack)

        conn2_open = await obs([(TCP, lambda tcp: from_frr_conn2(tcp) and tcp.haslayer(BGPHeader))])
        _logger.info("RX OPEN: %r", conn2_open.pkt)
        conn2_olen = len(conn2_open.pkt.getlayer(TCP).payload)

        # now resume FRR's outgoing connection (conn1)

        conn1_synack = common_hdr.copy()
        conn1_synack /= TCP(
            sport=bgp_port,
            dport=conn1_port,
            seq=conn1_isn,
            ack=_seq_add(conn1_tcp.seq, 1),
            flags="SA",
        )
        _logger.info("TX SYNACK: %r", conn1_synack)
        obs.send(conn1_synack)

        conn1_ack = await obs([(TCP, lambda tcp: tcp.flags.A and from_frr_conn1(tcp))])
        _logger.info("RX ACK: %r", conn1_ack.pkt)

        # wait specifically for a BGP message on conn1; matching on the ACK
        # flag alone could pick up a bare ACK or a conn2 retransmission and
        # corrupt conn1_olen (and thus every ack number below)
        conn1_open = await obs([(TCP, lambda tcp: from_frr_conn1(tcp) and tcp.haslayer(BGPHeader))])
        _logger.info("RX OPEN: %r", conn1_open.pkt)
        conn1_olen = len(conn1_open.pkt.getlayer(TCP).payload)

        # we now have 2 OPENs from FRR - shoot out 2 OPENs from ourselves

        my_open = BGPHeader() / BGPOpen(my_as=65002, bgp_id="10.255.0.2")
        my_open_len = len(my_open)

        # our next sequence number on each connection (SYN consumed one), and
        # the ack acknowledging FRR's SYN plus its OPEN
        conn1_txseq = _seq_add(conn1_isn, 1)
        conn1_rxack = _seq_add(conn1_tcp.seq, 1 + conn1_olen)
        conn2_txseq = _seq_add(conn2_isn, 1)
        conn2_rxack = _seq_add(conn2_tcp.seq, 1 + conn2_olen)

        conn1_txopen = common_hdr.copy()
        conn1_txopen /= TCP(sport=bgp_port, dport=conn1_port, seq=conn1_txseq, ack=conn1_rxack, flags="PA")
        conn1_txopen /= my_open.copy()
        _logger.info("TX OPEN: %r", conn1_txopen)
        obs.send(conn1_txopen)

        conn2_txopen = common_hdr.copy()
        conn2_txopen /= TCP(sport=conn2_port, dport=bgp_port, seq=conn2_txseq, ack=conn2_rxack, flags="PA")
        conn2_txopen /= my_open.copy()
        _logger.info("TX OPEN: %r", conn2_txopen)
        obs.send(conn2_txopen)

        # still in Connect... go into Established: send a KEEPALIVE (a bare
        # BGPHeader(); scapy's default message type is KEEPALIVE) right after
        # our OPEN on each connection

        conn1_txka = common_hdr.copy()
        conn1_txka /= TCP(
            sport=bgp_port,
            dport=conn1_port,
            seq=_seq_add(conn1_txseq, my_open_len),
            ack=conn1_rxack,
            flags="PA",
        )
        conn1_txka /= BGPHeader()
        _logger.info("TX KA: %r", conn1_txka)
        obs.send(conn1_txka)

        conn2_txka = common_hdr.copy()
        conn2_txka /= TCP(
            sport=conn2_port,
            dport=bgp_port,
            seq=_seq_add(conn2_txseq, my_open_len),
            ack=conn2_rxack,
            flags="PA",
        )
        conn2_txka /= BGPHeader()
        _logger.info("TX KA: %r", conn2_txka)
        obs.send(conn2_txka)

        # log every TCP packet seen from here on; the loop only ends when
        # obs() returns a falsy value or raises
        while tpkt := await obs([(TCP, lambda tcp: True)]):
            _logger.info("RX: %r", tpkt.pkt)


class BGPCollide(TestBase, AutoFixture, setup=Setup):
    """
    Check that the DUT ends up with an Established BGP session to the peer.

    Runs on the ``Setup`` network; the collision itself is presumably driven
    by the trait registered on ``Setup`` (TODO confirm trait/test ordering).
    """

    @topotatofunc
    def bgp_converge(self, topo, dut, peer):
        """
        Assert the DUT's IPv4 unicast BGP summary shows the peer Established.

        The check is delegated to ``AssertVtysh.make`` with ``maxwait=5.0``
        and compares against the JSON subset built below.  ``topo``, ``dut``
        and ``peer`` are supplied by the test framework; ``topo`` is unused.
        """
        # the DUT knows the peer by its LAN IPv4 address
        peer_ip = str(peer.iface_to("lan").ip4[0].ip)
        expected = {
            "peers": {
                peer_ip: {
                    "state": "Established",
                    "peerState": "OK",
                }
            }
        }
        yield from AssertVtysh.make(
            dut,
            "bgpd",
            "show bgp ipv4 unicast summary json",
            maxwait=5.0,
            compare=expected,
        )