#!/usr/bin/env python
# -*- coding: utf-8 -*-

from twisted.trial import unittest
from framework import run

from twisted.internet import reactor, protocol
from twisted.protocols import loopback
from twisted.internet import defer

from compass.ussd import USSServerProtocol
from compass.ussc import USSClientProtocol, USSClientFactory
from compass.uss.protocols import byteprotocol
from compass.uss.message import usspmsg
import struct

# Short aliases for the Twisted helpers used by this module.
# dResult is not referenced by any of the code visible in this file.
dResult = unittest.deferredResult
# Called as loopbackFunc(server, client) by the test cases to connect a
# server protocol and a client protocol to each other.
loopbackFunc = loopback.loopback

def packmsg(name, items):
    """Build a USSPMessage called *name* and return its packed form.

    *items* is an iterable of (field, value) pairs; each pair is stored on
    the message body with setField, in the order given.
    """
    message = usspmsg.USSPMessage()
    message.setMsgName(name)
    for field, value in items:
        message.body.setField(field, value)
    packed = message.packed()
    return packed

def getpacklen(packet):
    """Return the length stored in the first four bytes of *packet*.

    The prefix is decoded as a big-endian unsigned 32-bit integer.
    struct.error is raised when fewer than four bytes are available.
    """
    header = packet[:4]
    return struct.unpack('>I', header)[0]

def unpackmsg(packet):
    """Split *packet* into the USSPMessage objects it contains.

    *packet* is treated as a sequence of back-to-back messages. The length
    read by getpacklen is used as the size of the whole message, counted
    from the start of its length prefix, so each message is sliced off the
    front of the buffer and loaded with USSPMessage.loadMessage.

    Returns a list of the loaded messages in the order they appear; an
    empty *packet* yields an empty list.

    Raises ValueError when a message declares a length of zero, because
    the buffer could then never be consumed and the loop would not end.
    """
    messages = []
    while packet:
        length = getpacklen(packet)
        if length == 0:
            # A zero length would leave `packet` unchanged below and spin
            # forever, appending messages until memory runs out.
            raise ValueError('zero-length message prefix in packet')
        p = packet[:length]
        packet = packet[length:]
        m = usspmsg.USSPMessage()
        m.loadMessage(p)
        messages.append(m)
    return messages

class DummyUSSClientProtocol(byteprotocol.ByteMessageProtocol):
    """Bare-bones client used to drive a server protocol in the tests.

    It announces the connection through the `conn` Deferred, remembers the
    data handed to dataReceived and then drops the connection.
    """
    # Default for the most recently received data; replaced per instance
    # by dataReceived.
    packet=''

    def __init__(self):
        # Fired with None from connectionMade.
        # NOTE(review): the base class __init__ (if any) is not called here.
        self.conn = defer.Deferred()

    def connectionMade(self):
        # Let whoever attached callbacks to `conn` know the link is up.
        self.conn.callback(None)

    def dataReceived(self, data):
        # Keep only this chunk (earlier data, if any, is overwritten), then
        # close the connection.
        self.packet = data
        self.transport.loseConnection()

class ConnectServerTestCase(unittest.TestCase):
    """Send one request to USSServerProtocol and check the single reply.

    `sendmsg` and `recvmsg` are (message name, ((field, value), ...))
    pairs describing the request to send and the reply expected back.
    """

    sendmsg = (
        'connect',
        (
            ('system_id', 'sys001'),
            ('auth_source', 'xyb4567890123456'),
            ('version', 0x0001),
            ('time_stamp', '20041109'),
        ),
    )
    recvmsg = (
        'connect_resp',
        (
            ('status', 0x0001),
            ('version', 0x0001),
        ),
    )

    def testServer(self):
        """Exchange sendmsg over a loopback link and compare with recvmsg."""
        request_name, request_fields = self.sendmsg
        expected_name, expected_fields = self.recvmsg

        server = USSServerProtocol()
        client = DummyUSSClientProtocol()
        request = packmsg(request_name, request_fields)

        def send_request(_ignored):
            # Runs once the dummy client reports that it is connected.
            client.sendData(request)

        client.conn.addCallback(send_request)
        loopbackFunc(server, client)

        # Exactly one message must have come back from the server.
        [reply] = unpackmsg(client.packet)
        self.assertEquals(reply.msgname, expected_name)
        for field, value in expected_fields:
            self.assertEquals(reply.body.fields[field], value)

class MailCounterServerTestCase(ConnectServerTestCase):
    """Reuse the inherited testServer with a 'mail_counter' request."""

    sendmsg = (
        'mail_counter',
        (
            ('uid', 'mail76'),
        ),
    )
    recvmsg = (
        'mail_counter_resp',
        (
            ('uid', 'mail76'),
            ('number', 2028),
        ),
    )

class TerminateServerTestCase(ConnectServerTestCase):
    """Reuse the inherited testServer with a field-less 'terminate' request."""

    sendmsg = ('terminate', ())
    recvmsg = ('terminate_resp', ())

class TestUSSServerProtocol(USSServerProtocol):
    """USSServerProtocol that records the raw data it receives.

    Every chunk passed to dataReceived is appended to `packets` after the
    base class has processed it. When the last message in a chunk is named
    'terminate' the connection is closed.
    """

    def __init__(self):
        USSServerProtocol.__init__(self)
        # Per-instance list. A class-level list would be shared by every
        # instance, so data recorded by one test would leak into the next.
        self.packets = []
        self.conn = defer.Deferred()

    def connectionMade(self):
        USSServerProtocol.connectionMade(self)
        # NOTE: self.conn is never fired; the callback below is disabled.
        #self.conn.callback(None)

    def connectionLost(self, reason):
        # No-op override: the base class connectionLost is not called.
        pass

    def dataReceived(self, data):
        USSServerProtocol.dataReceived(self, data)
        self.packets.append(data)
        # Only the final message of the chunk decides whether to hang up.
        if unpackmsg(data)[-1].msgname == 'terminate':
            self.transport.loseConnection()

class ClientTestCase(unittest.TestCase):
    """Check the messages a USSClientFactory-built client sends to a server.

    `queue` lists, in order, the (message name, ((field, value), ...))
    entries the recording server is expected to have received.
    """

    queue = (
        ('connect', (
            ('auth_source', '1234567890123456'),
            ('version', 256),
            ('system_id', '123456'),
            ('time_stamp', '12345678'),
        )),
        ('mail_counter', (
            ('uid', '1'),
        )),
        ('terminate', ()),
    )

    def tearDown(self):
        """Cancel every delayed call still registered with the reactor."""
        pending = reactor.getDelayedCalls()
        for call in pending:
            call.cancel()

    def testClient(self):
        """Run the client against the recording server and compare to queue."""
        server = TestUSSServerProtocol()
        client = USSClientFactory().buildProtocol(None)
        loopbackFunc(server, client)

        # Flatten every recorded chunk into one ordered list of messages.
        received = []
        for chunk in server.packets:
            received.extend(unpackmsg(chunk))

        self.assertEquals(len(received), len(self.queue))
        for message, (name, fields) in zip(received, self.queue):
            self.assertEquals(message.msgname, name)
            for field, value in fields:
                self.assertEquals(message.body.fields[field], value)

# Hand control to framework.run when this file is executed as a script.
if __name__ == '__main__':
    run()
